From c64e657632747de6230eb369ef3db824d9a56d54 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sat, 18 Jan 2025 14:57:31 -0800 Subject: [PATCH] PEP585 update - torch/distributed/fsdp (#145162) See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145162 Approved by: https://github.com/bobrenjc93 --- torch/distributed/fsdp/_common_utils.py | 54 ++-- torch/distributed/fsdp/_debug_utils.py | 12 +- torch/distributed/fsdp/_dynamo_utils.py | 4 +- torch/distributed/fsdp/_exec_order_utils.py | 22 +- torch/distributed/fsdp/_flat_param.py | 110 ++++---- torch/distributed/fsdp/_fsdp_extensions.py | 6 +- .../fsdp/_fully_shard/_fsdp_collectives.py | 60 ++--- .../fsdp/_fully_shard/_fsdp_common.py | 4 +- .../fsdp/_fully_shard/_fsdp_init.py | 24 +- .../fsdp/_fully_shard/_fsdp_param.py | 21 +- .../fsdp/_fully_shard/_fsdp_param_group.py | 34 +-- .../fsdp/_fully_shard/_fsdp_state.py | 27 +- .../fsdp/_fully_shard/_fully_shard.py | 22 +- torch/distributed/fsdp/_init_utils.py | 99 +++---- torch/distributed/fsdp/_limiter_utils.py | 4 +- torch/distributed/fsdp/_optim_utils.py | 255 +++++++++--------- torch/distributed/fsdp/_runtime_utils.py | 42 +-- torch/distributed/fsdp/_state_dict_utils.py | 39 +-- torch/distributed/fsdp/_trace_utils.py | 22 +- torch/distributed/fsdp/_traversal_utils.py | 17 +- .../distributed/fsdp/_unshard_param_utils.py | 3 +- torch/distributed/fsdp/_wrap_utils.py | 32 +-- torch/distributed/fsdp/api.py | 5 +- .../fsdp/fully_sharded_data_parallel.py | 73 +++-- torch/distributed/fsdp/sharded_grad_scaler.py | 13 +- torch/distributed/fsdp/wrap.py | 69 ++--- 26 files changed, 497 insertions(+), 576 deletions(-) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index be99b8cb7f9..63ace36da62 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -6,22 +6,10 @@ import logging import traceback import warnings import weakref +from collections.abc import Generator, Iterable from enum import auto, Enum from functools import partial -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - Iterable, - List, - no_type_check, - Optional, - Set, - Type, - TYPE_CHECKING, -) +from typing import Any, Callable, cast, no_type_check, Optional, TYPE_CHECKING import torch import torch.distributed as dist @@ -118,10 +106,10 @@ class _FSDPState(_State): def __init__(self) -> None: # TODO: Move all the attributes to this class to enable typing for # FSDP/fully_shard. 
- self._ignored_modules: Set[nn.Module] = set() - self._ignored_params: Set[nn.Parameter] = set() + self._ignored_modules: set[nn.Module] = set() + self._ignored_params: set[nn.Parameter] = set() # Buffer names are cleaned (without wrapper prefixes) - self._ignored_buffer_names: Set[str] = set() + self._ignored_buffer_names: set[str] = set() self.process_group: Optional[dist.ProcessGroup] = None self.rank: int = -1 self.world_size: int = -1 @@ -129,13 +117,13 @@ class _FSDPState(_State): self.sharding_strategy = ShardingStrategy.FULL_SHARD self._use_orig_params: bool = False self.training_state = TrainingState.IDLE - self._unshard_params_ctx: Dict[nn.Module, Generator] = {} + self._unshard_params_ctx: dict[nn.Module, Generator] = {} self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT self._state_dict_config: StateDictConfig = FullStateDictConfig() self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig() self._is_root: Optional[bool] = None self._handle: Optional[flat_param_file.FlatParamHandle] = None - self._fully_sharded_module_to_handle: Dict[ + self._fully_sharded_module_to_handle: dict[ nn.Module, Optional[flat_param_file.FlatParamHandle] ] = {} self.compute_device: Optional[torch.device] = None @@ -149,8 +137,8 @@ class _FSDPState(_State): self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() # All following attributes should only be used for root states: # Save these static lists to avoid the repeated tree traversals - self._all_fsdp_states: List[_FSDPState] = [] - self._all_handles: List[flat_param_file.FlatParamHandle] = [] + self._all_fsdp_states: list[_FSDPState] = [] + self._all_handles: list[flat_param_file.FlatParamHandle] = [] self._fsdp_extension: Optional[FSDPExtensions] = None @@ -262,7 +250,7 @@ def _is_fsdp_flattened(tensor: torch.Tensor) -> bool: def _named_parameters_with_duplicates( module: nn.Module, **kwargs: Any -) -> List[tuple[str, nn.Parameter]]: +) -> list[tuple[str, nn.Parameter]]: """ This API is required as some modules overwrite `named_parameters()` but do not support `remove_duplicate`. @@ -282,7 +270,7 @@ def _named_parameters_with_duplicates( def _get_param_to_fqns( model: torch.nn.Module, dedup_shared_params: bool = True, -) -> Dict[nn.Parameter, List[str]]: +) -> dict[nn.Parameter, list[str]]: """ Constructs a mapping from parameter to a list of its \"canonical\" FQNs. 
Here, we use canonical to mean the fully-qualified name assigned to the parameter @@ -352,7 +340,7 @@ def _get_param_to_fqns( def return_fn(param_to_fqns): return param_to_fqns - param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {} + param_to_unflat_param_names: dict[torch.nn.Parameter, list[str]] = {} return _apply_to_modules( model, module_fn, @@ -377,7 +365,7 @@ def _log_post_backward_hook( @no_type_check def _get_handle_fqns_from_root( state: _FSDPState, handle: "FlatParamHandle" -) -> Optional[List[str]]: +) -> Optional[list[str]]: if handle is None: return None param_to_fqn = state._exec_order_data.param_to_fqn @@ -392,7 +380,7 @@ def _apply_to_modules( root_module: torch.nn.Module, module_fn: Callable, return_fn: Callable, - filter_fqns: Optional[List[str]] = None, + filter_fqns: Optional[list[str]] = None, *args, **kwargs, ): @@ -443,7 +431,7 @@ def _apply_to_modules( @no_type_check def _assert_in_training_states( state: _FSDPState, - training_states: List[TrainingState], + training_states: list[TrainingState], ) -> None: """Asserts that FSDP is in the states ``_training_states``.""" # Raise a `ValueError` instead of using `assert` to ensure that these @@ -462,7 +450,7 @@ def _assert_in_training_states( raise ValueError(msg) -def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]: +def _get_root_modules(modules: set[nn.Module]) -> set[nn.Module]: """ Returns: Set[nn.Module]: The subset of ``modules`` that are root modules (i.e. @@ -470,7 +458,7 @@ def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]: words, these are the modules in ``modules`` that are not the child of any other module in ``modules``. """ - root_modules: Set[nn.Module] = set() + root_modules: set[nn.Module] = set() module_to_submodules = {module: set(module.modules()) for module in modules} for candidate_module in modules: is_root_module = True @@ -488,12 +476,12 @@ def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]: def _override_module_mixed_precision( root: torch.nn.Module, - module_classes_to_override: Iterable[Type[nn.Module]], - wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006 -) -> Set[Type[nn.Module]]: + module_classes_to_override: Iterable[type[nn.Module]], + wrap_override_dict: dict[str, Any] = {"mixed_precision": None}, # noqa: B006 +) -> set[type[nn.Module]]: module_classes_to_override = tuple(set(module_classes_to_override)) # Return a set of the actually overridden module classes - overridden_module_classes: Set[Type[nn.Module]] = set() + overridden_module_classes: set[type[nn.Module]] = set() for mod in root.modules(): if isinstance(mod, module_classes_to_override): overridden_module_classes.add(type(mod)) diff --git a/torch/distributed/fsdp/_debug_utils.py b/torch/distributed/fsdp/_debug_utils.py index 0b014f7a2ba..2103da08a97 100644 --- a/torch/distributed/fsdp/_debug_utils.py +++ b/torch/distributed/fsdp/_debug_utils.py @@ -2,9 +2,9 @@ import logging import time from collections import defaultdict +from collections.abc import Iterator from contextlib import contextmanager from enum import Enum -from typing import Dict, Iterator, List, Set import torch import torch.distributed as dist @@ -28,8 +28,8 @@ class SimpleProfiler: H2D = "H2D" D2H = "D2H" - results: Dict[str, float] = defaultdict(float) - profiling: Set[str] = set() + results: dict[str, float] = defaultdict(float) + profiling: set[str] = set() @classmethod def reset(cls) -> None: @@ -65,7 +65,7 @@ class SimpleProfiler: def 
_get_sharded_module_tree_with_module_name_to_fqns( model: torch.nn.Module, -) -> tuple[str, Dict[str, List[str]]]: +) -> tuple[str, dict[str, list[str]]]: """ It is used for composable fully_shard() code path, it returns 1. sharded module tree info: each line reprents a submodule name that contats the @@ -143,10 +143,10 @@ def _get_sharded_module_tree_with_module_name_to_fqns( return sharded_tree_info[0], sharded_module_name_to_fqns # Use List to mutate its value in place while running the recursive functions - sharded_tree_info: List[str] = [ + sharded_tree_info: list[str] = [ "", ] - sharded_module_name_to_fqns: Dict[str, List[str]] = {} + sharded_module_name_to_fqns: dict[str, list[str]] = {} return _apply_to_modules( model, module_fn, diff --git a/torch/distributed/fsdp/_dynamo_utils.py b/torch/distributed/fsdp/_dynamo_utils.py index 38bb95596b0..77bcd43b63b 100644 --- a/torch/distributed/fsdp/_dynamo_utils.py +++ b/torch/distributed/fsdp/_dynamo_utils.py @@ -1,11 +1,9 @@ -from typing import Set - import torch.nn as nn def _annotate_modules_for_dynamo( module: nn.Module, - ignored_modules: Set[nn.Module], + ignored_modules: set[nn.Module], use_orig_params: bool, ) -> None: """ diff --git a/torch/distributed/fsdp/_exec_order_utils.py b/torch/distributed/fsdp/_exec_order_utils.py index 2405fae46b2..b19e919de29 100644 --- a/torch/distributed/fsdp/_exec_order_utils.py +++ b/torch/distributed/fsdp/_exec_order_utils.py @@ -2,7 +2,7 @@ import itertools import warnings from enum import auto, Enum -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch import torch.distributed as dist @@ -37,9 +37,9 @@ class _ExecOrderData: ) -> None: # Tracks the (static) pre-forward order for execution order validation # and forward prefetching - self.handles_pre_forward_order: List[FlatParamHandle] = [] + self.handles_pre_forward_order: list[FlatParamHandle] = [] # Tracks the post-forward order for pre-backward prefetching - self.handles_post_forward_order: List[Optional[FlatParamHandle]] = [] + self.handles_post_forward_order: list[Optional[FlatParamHandle]] = [] self._iter = 0 # Gives the max number of backward/forward prefetched all-gathers by a @@ -51,9 +51,9 @@ class _ExecOrderData: self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL self.process_group: Optional[dist.ProcessGroup] = None self.world_size: Optional[int] = None - self.all_handles: List[FlatParamHandle] = [] + self.all_handles: list[FlatParamHandle] = [] # Names are prefixed from the root module - self.param_to_fqn: Dict[nn.Parameter, List[str]] = {} + self.param_to_fqn: dict[nn.Parameter, list[str]] = {} # Current index in the pre-forward execution order self.current_order_index = 0 self.warn_status = _ExecOrderWarnStatus.NONE @@ -197,7 +197,7 @@ class _ExecOrderData: num_valid_indices = sum( (index is not None) for index in optional_local_indices ) - tensor_kwargs: Dict[str, Union[torch.dtype, torch.device]] = { + tensor_kwargs: dict[str, Union[torch.dtype, torch.device]] = { "dtype": torch.int32, "device": device, } @@ -313,7 +313,7 @@ class _ExecOrderData: corresponding to the handles in ``handle``. An entry in the returned tuple is ``None`` if the handle is invalid. 
""" - indices: List[Optional[int]] = [] + indices: list[Optional[int]] = [] if handle: indices.append(handle._handle_index) return tuple(indices) @@ -321,13 +321,13 @@ class _ExecOrderData: def _get_names_from_handle_indices( self, handle_indices: tuple[int, ...], - ) -> List[List[str]]: + ) -> list[list[str]]: """ Returns a list of FQNs for each handle in ``handle_indices``. If a handle index is invalid, then its FQNs are omitted from the returned list. """ - fqns: List[List[str]] = [] + fqns: list[list[str]] = [] for index in handle_indices: if index is None or index < 0 or index >= len(self.all_handles): continue @@ -339,12 +339,12 @@ class _ExecOrderData: def _get_names_from_handles( self, handle: FlatParamHandle, - ) -> List[List[str]]: + ) -> list[list[str]]: """ Returns a list of FQNs for each handle in ``handles_key``. If a handle is invalid, then its FQNs are omitted from the returned list. """ - fqns: List[List[str]] = [] + fqns: list[list[str]] = [] if handle: flat_param = handle.flat_param if flat_param in self.param_to_fqn: diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 98a2fc3cd82..85d85c3f2e0 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -4,24 +4,10 @@ import functools import logging import os import warnings +from collections.abc import Generator, Iterator, Sequence from enum import auto, Enum from itertools import accumulate, chain -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - Iterator, - List, - NamedTuple, - no_type_check, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import Any, Callable, cast, NamedTuple, no_type_check, Optional, Union import torch import torch.distributed as dist @@ -354,7 +340,7 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): _numels: tuple[int, ...] _shard_param_infos: tuple[_ShardParamInfo, ...] _shared_param_infos: tuple[SharedParamInfo, ...] 
- _modules: Set[nn.Module] + _modules: set[nn.Module] _shard_numel_padded: int _local_shard: Tensor _full_param_padded: Tensor @@ -366,12 +352,12 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): _mp_shard: Tensor _cpu_grad: Tensor _saved_grad_shard: Tensor - _params: Optional[List[nn.Parameter]] - _shared_params: Optional[List[nn.Parameter]] - _tensors: Optional[List[Optional[Tensor]]] - _is_grad_none_mask: Optional[List[bool]] + _params: Optional[list[nn.Parameter]] + _shared_params: Optional[list[nn.Parameter]] + _tensors: Optional[list[Optional[Tensor]]] + _is_grad_none_mask: Optional[list[bool]] - _is_padding_mask: List[bool] + _is_padding_mask: list[bool] def __new__(cls, data=None, requires_grad=True): assert cls is FlatParameter, "subclasses FlatParameter not supported" @@ -386,17 +372,17 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): def _init_metadata( cls, self, - param_infos: List[ParamInfo], - numels: List[int], - shapes: List[torch.Size], - strides: List[tuple[int, ...]], - contiguities: List[bool], - fqns: List[str], - shared_param_infos: List[SharedParamInfo], - param_extensions: List[Optional[Any]], - params: Optional[List[nn.Parameter]], - shared_params: Optional[List[nn.Parameter]], - is_padding_mask: List[bool], + param_infos: list[ParamInfo], + numels: list[int], + shapes: list[torch.Size], + strides: list[tuple[int, ...]], + contiguities: list[bool], + fqns: list[str], + shared_param_infos: list[SharedParamInfo], + param_extensions: list[Optional[Any]], + params: Optional[list[nn.Parameter]], + shared_params: Optional[list[nn.Parameter]], + is_padding_mask: list[bool], ) -> None: """ Initialize attributes holding metadata about the original parameters comprising the flat parameter. @@ -426,7 +412,7 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): self._param_extensions = param_extensions self._is_padding_mask = is_padding_mask - numels_without_padding: List[int] = [] + numels_without_padding: list[int] = [] for numel, is_padding in zip(numels, is_padding_mask): if not is_padding: numels_without_padding.append(numel) @@ -624,7 +610,7 @@ class FlatParamHandle: def _init_flat_param_and_metadata( self, - params: List[Union[Tensor, nn.Parameter]], + params: list[Union[Tensor, nn.Parameter]], module: nn.Module, aligned_numel: int, use_orig_params: bool, @@ -653,20 +639,20 @@ class FlatParamHandle: params_set = set(params) # For alignment padding, only `numels` gets strictly non-`None` # elements, and all other lists get `None` elements for padding. 
- param_infos: List[ParamInfo] = [] - numels: List[int] = [] - shapes: List[torch.Size] = [] - strides: List[tuple[int, ...]] = [] - contiguities: List[bool] = [] - fqns: List[str] = [] - shared_param_infos: List[SharedParamInfo] = [] - shared_param_memo: Dict[ + param_infos: list[ParamInfo] = [] + numels: list[int] = [] + shapes: list[torch.Size] = [] + strides: list[tuple[int, ...]] = [] + contiguities: list[bool] = [] + fqns: list[str] = [] + shared_param_infos: list[SharedParamInfo] = [] + shared_param_memo: dict[ Union[Tensor, nn.Parameter], tuple[nn.Module, str, str] ] = {} - params_to_flatten: List[Union[Tensor, nn.Parameter]] = [] - shared_params: List[Union[Tensor, nn.Parameter]] = [] - param_extensions: List[Any] = [] - is_padding_mask: List[bool] = [] + params_to_flatten: list[Union[Tensor, nn.Parameter]] = [] + shared_params: list[Union[Tensor, nn.Parameter]] = [] + param_extensions: list[Any] = [] + is_padding_mask: list[bool] = [] total_numel = total_numel_without_padding = 0 for submodule_name, submodule in module.named_modules(remove_duplicate=False): for param_name, param in _named_parameters_with_duplicates( @@ -779,8 +765,8 @@ class FlatParamHandle: ) def _validate_tensors_to_flatten( - self, tensors: List[Union[Tensor, nn.Parameter]] - ) -> Tuple: + self, tensors: list[Union[Tensor, nn.Parameter]] + ) -> tuple: """Validate the tensors to flatten and returns any necessary metadata.""" dtype: Optional[torch.dtype] = None # Return as the logical OR over each tensor's value @@ -819,7 +805,7 @@ class FlatParamHandle: def flatten_tensors( self, - tensors: List[Tensor], + tensors: list[Tensor], aligned_numel: int, ) -> Tensor: """ @@ -841,7 +827,7 @@ class FlatParamHandle: f"Expects non-negative `aligned_numel` but got {aligned_numel}" ) dtype, _, device = self._validate_tensors_to_flatten(tensors) - flat_tensors: List[Tensor] = [] + flat_tensors: list[Tensor] = [] if aligned_numel > 0: total_numel = 0 for tensor in tensors: @@ -876,7 +862,7 @@ class FlatParamHandle: def flatten_tensors_into_flat_param( self, - tensors: List[Tensor], + tensors: list[Tensor], aligned_numel: int, requires_grad: bool, ) -> FlatParameter: @@ -1013,7 +999,7 @@ class FlatParamHandle: assert len(flat_param_offsets) == len( self.flat_param._numels_with_padding ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" - shard_param_infos: List[_ShardParamInfo] = [] + shard_param_infos: list[_ShardParamInfo] = [] sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices # into the unsharded flat parameter (inclusive) of the given parameter @@ -1130,7 +1116,7 @@ class FlatParamHandle: assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}" return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) - def _get_flat_param_offsets(self) -> List[tuple[int, int]]: + def _get_flat_param_offsets(self) -> list[tuple[int, int]]: """ Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding). @@ -1884,7 +1870,7 @@ class FlatParamHandle: def _get_unflat_views_aligned( self, tensor: Optional[Tensor] = None, - ) -> List[Tensor]: + ) -> list[Tensor]: """ Return unflattened ``Tensor`` views into ``tensor`` with handling for padding. 
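The hunks above follow the same mechanical pattern as the rest of the patch: PEP 585 (Python 3.9+) makes the builtin containers subscriptable, so `typing.List`/`Dict`/`Set`/`Tuple`/`Type` annotations become `list`/`dict`/`set`/`tuple`/`type`. A minimal, self-contained sketch of the before/after pattern, with made-up helper names that are not taken from the patch:

# Pre-PEP 585 style: container generics imported from `typing`.
from typing import Dict, List


def shard_numels_old(numels: List[int], world_size: int) -> Dict[int, List[int]]:
    """Toy helper (illustrative only): split each numel across ranks."""
    return {rank: [n // world_size for n in numels] for rank in range(world_size)}


# PEP 585 style (Python 3.9+): the builtin classes are subscriptable directly,
# so the `List`/`Dict`/`Set`/`Tuple`/`Type` imports can be dropped.
def shard_numels_new(numels: list[int], world_size: int) -> dict[int, list[int]]:
    return {rank: [n // world_size for n in numels] for rank in range(world_size)}


if __name__ == "__main__":
    assert shard_numels_old([8, 4], 2) == shard_numels_new([8, 4], 2)
    print(shard_numels_new([8, 4], 2))  # {0: [4, 2], 1: [4, 2]}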
@@ -1895,11 +1881,11 @@ class FlatParamHandle: flat_param = self.flat_param if tensor is None: tensor = flat_param - splits: List[Tensor] = torch.split( + splits: list[Tensor] = torch.split( tensor, flat_param._numels_with_padding, dim=0 ) idx = 0 - views: List[Tensor] = [] + views: list[Tensor] = [] for split, is_padding in zip(splits, flat_param._is_padding_mask): if is_padding: continue @@ -2462,7 +2448,7 @@ class FlatParamHandle: else: self._use_unsharded_views(as_params=True) - def _get_modules(self) -> Set[nn.Module]: + def _get_modules(self) -> set[nn.Module]: """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter.""" return {pi.module for pi in self.flat_param._param_infos}.union( {spi.module for spi in self.flat_param._shared_param_infos} @@ -2514,9 +2500,9 @@ class FlatParamHandle: yield (param_name, module_name) @property - def _fqns_in_shard(self) -> List[str]: + def _fqns_in_shard(self) -> list[str]: """Return the FQNs of the parameters present in this rank's shard.""" - fqns_in_shard: List[str] = [] + fqns_in_shard: list[str] = [] for fqn, shard_param_info in zip( self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined] ): @@ -2708,8 +2694,8 @@ def _safe_setattr_tensor_or_param( def _convert_to_params( - tensors: List[Union[torch.Tensor, nn.Parameter]] -) -> List[nn.Parameter]: + tensors: list[Union[torch.Tensor, nn.Parameter]] +) -> list[nn.Parameter]: return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] diff --git a/torch/distributed/fsdp/_fsdp_extensions.py b/torch/distributed/fsdp/_fsdp_extensions.py index 34edfbe904a..f861a90ce58 100644 --- a/torch/distributed/fsdp/_fsdp_extensions.py +++ b/torch/distributed/fsdp/_fsdp_extensions.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Optional import torch import torch.distributed as dist @@ -64,7 +64,7 @@ class FSDPExtensions(ABC): def pre_load_state_dict_transform( self, tensor: torch.Tensor, - ) -> tuple[torch.Tensor, List[Shard]]: + ) -> tuple[torch.Tensor, list[Shard]]: """ This is to be called before loading a *sharded* model state dict and should return the tensor and list of shards from which to load data. 
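The other recurring change is visible in the import hunks: abstract container types such as `Generator`, `Iterable`, `Iterator`, and `Sequence` now come from `collections.abc` rather than their deprecated `typing` aliases. A small illustrative sketch of that pattern (the names are invented, not taken from the FSDP code):

from collections.abc import Iterable, Iterator


def iter_clean_names(groups: Iterable[list[str]]) -> Iterator[str]:
    # Generator function annotated with the collections.abc types; the
    # subscripted forms work at runtime on Python 3.9+.
    for group in groups:
        yield from group


if __name__ == "__main__":
    names = list(iter_clean_names([["lin1.weight"], ["lin2.weight", "lin2.bias"]]))
    assert names == ["lin1.weight", "lin2.weight", "lin2.bias"]
    # The collections.abc classes also double as runtime ABCs for isinstance checks.
    assert isinstance(names, Iterable)
    print(names)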
@@ -157,7 +157,7 @@ def _ext_chunk_dtensor( def _ext_pre_load_state_dict_transform( tensor: torch.Tensor, fsdp_extension: Optional[FSDPExtensions] = None, -) -> tuple[torch.Tensor, List[Shard]]: +) -> tuple[torch.Tensor, list[Shard]]: if fsdp_extension is not None: return fsdp_extension.pre_load_state_dict_transform(tensor) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 73cce0b320c..f24cdd939d3 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -1,4 +1,4 @@ -from typing import cast, List, NamedTuple, Optional, Union +from typing import cast, NamedTuple, Optional, Union import torch import torch.distributed as dist @@ -20,12 +20,12 @@ class AllGatherResult(NamedTuple): all_gather_event: Optional[torch.Event] all_gather_work: Optional[dist.distributed_c10d.Work] # For each parameter, the all-gather input dtype for each input - param_all_gather_input_dtypes: List[List[torch.dtype]] + param_all_gather_input_dtypes: list[list[torch.dtype]] # For each parameter, the all-gather input numel for each input - param_all_gather_input_numels: List[List[int]] + param_all_gather_input_numels: list[list[int]] # 1D flattened version of `param_all_gather_input_numels` saved to avoid # CPU overhead from recomputing - all_gather_input_split_sizes: List[int] + all_gather_input_split_sizes: list[int] lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 @@ -47,8 +47,8 @@ lib.define( @torch.library.impl(lib, "all_gather_copy_in", "Meta") def all_gather_copy_in_meta( - all_gather_inputs: List[torch.Tensor], - inp_split_sizes: List[int], + all_gather_inputs: list[torch.Tensor], + inp_split_sizes: list[int], all_gather_input_numel: int, world_size: int, rank: int, @@ -68,8 +68,8 @@ def all_gather_copy_in_meta( @torch.library.impl(lib, "all_gather_copy_in", "XPU") @torch.library.impl(lib, "all_gather_copy_in", "CPU") def all_gather_copy_in_cuda( - all_gather_inputs: List[torch.Tensor], - inp_split_sizes: List[int], + all_gather_inputs: list[torch.Tensor], + inp_split_sizes: list[int], all_gather_input_numel: int, world_size: int, rank: int, @@ -99,9 +99,9 @@ lib.define( @torch.library.impl(lib, "split_with_sizes_copy", "CPU") def split_with_sizes_copy( all_gather_output: torch.Tensor, - all_gather_input_split_sizes: List[int], + all_gather_input_split_sizes: list[int], dim: int, - out: List[torch.Tensor], + out: list[torch.Tensor], ) -> None: torch.split_with_sizes_copy( all_gather_output, all_gather_input_split_sizes, dim=dim, out=out @@ -118,7 +118,7 @@ lib.define( @torch.library.impl(lib, "chunk_cat", "XPU") @torch.library.impl(lib, "chunk_cat", "CPU") def chunk_cat( - tensors: List[torch.Tensor], + tensors: list[torch.Tensor], dim: int, num_chunks: int, out: torch.Tensor, @@ -128,7 +128,7 @@ def chunk_cat( @torch.no_grad() def foreach_all_gather( - fsdp_params: List[FSDPParam], + fsdp_params: list[FSDPParam], group: dist.ProcessGroup, async_op: bool, all_gather_copy_in_stream: torch.Stream, @@ -183,8 +183,8 @@ def foreach_all_gather( @torch.no_grad() def _get_param_all_gather_inputs( - fsdp_params: List[FSDPParam], -) -> List[List[torch.Tensor]]: + fsdp_params: list[FSDPParam], +) -> list[list[torch.Tensor]]: if compiled_autograd_enabled(): return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params] @@ -198,10 +198,10 @@ def _get_param_all_gather_inputs( and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather") 
) - param_all_gather_inputs: List[List[torch.Tensor]] = [[] for _ in fsdp_params] - foreach_copy_indices: List[int] = [] - foreach_copy_inputs: List[torch.Tensor] = [] - foreach_copy_input_numels: List[int] = [] + param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params] + foreach_copy_indices: list[int] = [] + foreach_copy_inputs: list[torch.Tensor] = [] + foreach_copy_input_numels: list[int] = [] # 1st pass: for foreach-copy parameters, get inputs and metadata for the # foreach copy, and for the others, actually get their all-gather inputs @@ -236,7 +236,7 @@ def _get_param_all_gather_inputs( @torch.no_grad() def foreach_all_gather_copy_out( all_gather_result: AllGatherResult, - fsdp_params: List[FSDPParam], + fsdp_params: list[FSDPParam], group: dist.ProcessGroup, ) -> None: ( @@ -255,8 +255,8 @@ def foreach_all_gather_copy_out( all_gather_work.wait() world_size, device = group.size(), all_gather_output.device - split_with_sizes_out: List[torch.Tensor] = [] - shard_i_copy_infos: List[tuple[FSDPParam, List[torch.Tensor]]] = [] + split_with_sizes_out: list[torch.Tensor] = [] + shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = [] for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params ): @@ -320,8 +320,8 @@ def foreach_all_gather_copy_out( @torch.no_grad() def foreach_reduce( - fsdp_params: List[FSDPParam], - unsharded_grads: List[torch.Tensor], + fsdp_params: list[FSDPParam], + unsharded_grads: list[torch.Tensor], reduce_scatter_group: dist.ProcessGroup, reduce_scatter_stream: torch.Stream, orig_dtype: torch.dtype, @@ -488,7 +488,7 @@ def foreach_reduce( def foreach_reduce_scatter_copy_in( - unsharded_grads: List[torch.Tensor], + unsharded_grads: list[torch.Tensor], reduce_scatter_input: torch.Tensor, world_size: int, ) -> None: @@ -499,14 +499,14 @@ def foreach_reduce_scatter_copy_in( def _get_all_gather_input_metadatas( - param_all_gather_inputs: List[List[torch.Tensor]], -) -> tuple[List[List[torch.dtype]], List[List[int]], torch.dtype]: - param_all_gather_input_dtypes: List[List[torch.dtype]] = [] - param_all_gather_input_numels: List[List[int]] = [] + param_all_gather_inputs: list[list[torch.Tensor]], +) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]: + param_all_gather_input_dtypes: list[list[torch.dtype]] = [] + param_all_gather_input_numels: list[list[int]] = [] all_gather_dtype = param_all_gather_inputs[0][0].dtype for all_gather_inputs in param_all_gather_inputs: - input_dtypes: List[torch.dtype] = [] - input_numels: List[int] = [] + input_dtypes: list[torch.dtype] = [] + input_numels: list[int] = [] for all_gather_input in all_gather_inputs: if all_gather_input.dtype != all_gather_dtype: all_gather_dtype = torch.uint8 diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index b630217015b..1f6c3784968 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -3,7 +3,7 @@ import math import traceback from dataclasses import dataclass from enum import auto, Enum -from typing import Any, List, Optional +from typing import Any, Optional import torch import torch.distributed as dist @@ -120,7 +120,7 @@ def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Si def _chunk_with_empty( tensor: torch.Tensor, num_chunks: int, dim: int -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: chunks = 
list(torch.chunk(tensor, num_chunks, dim=dim)) while len(chunks) < num_chunks: chunks.append(chunks[0].new_empty(0)) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py index 2c8c4cfe21e..63657b8692f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py @@ -1,5 +1,5 @@ import itertools -from typing import List, Optional, Set, Union +from typing import Optional, Union import torch import torch.distributed as dist @@ -70,11 +70,11 @@ def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device: return torch.device(mesh.device_type, device_handle.current_device()) -def _get_managed_modules(root_modules: tuple[nn.Module, ...]) -> List[nn.Module]: - modules: List[nn.Module] = [] +def _get_managed_modules(root_modules: tuple[nn.Module, ...]) -> list[nn.Module]: + modules: list[nn.Module] = [] root_modules_set = set(root_modules) # Track visisted modules to avoid visiting shared modules multiple times - visited_modules: Set[nn.Module] = set() + visited_modules: set[nn.Module] = set() def dfs(module: nn.Module) -> None: """ @@ -113,14 +113,14 @@ def _verify_managed_param(name: str, param: nn.Parameter) -> None: def _get_managed_states( - modules: List[nn.Module], -) -> tuple[List[nn.Parameter], List[torch.Tensor]]: - params: List[nn.Parameter] = [] - buffers: List[torch.Tensor] = [] + modules: list[nn.Module], +) -> tuple[list[nn.Parameter], list[torch.Tensor]]: + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] # Track visited parameters/buffers to avoid visiting shared parameters and # buffers multiple times - visited_params: Set[nn.Parameter] = set() - visited_buffers: Set[torch.Tensor] = set() + visited_params: set[nn.Parameter] = set() + visited_buffers: set[torch.Tensor] = set() for module in modules: for name, param in module.named_parameters(recurse=False): if param not in visited_params: @@ -135,8 +135,8 @@ def _get_managed_states( def _move_states_to_device( - params: List[nn.Parameter], - buffers: List[torch.Tensor], + params: list[nn.Parameter], + buffers: list[torch.Tensor], device: torch.device, ) -> None: """ diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index cbaebfcffb2..b20180893a9 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs import inspect import itertools +from collections.abc import Sequence from dataclasses import dataclass, field from enum import auto, Enum -from typing import Any, Callable, cast, List, Optional, Sequence +from typing import Any, Callable, cast, Optional import torch import torch.nn as nn @@ -171,8 +172,8 @@ class ParamModuleInfo: # Parameter names are unprefixed, e.g. 
"weight", not "lin.weight" module: nn.Module param_name: str - shared_modules: List[nn.Module] = field(default_factory=list) - shared_param_names: List[str] = field(default_factory=list) + shared_modules: list[nn.Module] = field(default_factory=list) + shared_param_names: list[str] = field(default_factory=list) @dataclass @@ -211,10 +212,10 @@ class FSDPParam: _sharding_spec: DTensorSpec # DTensor attributes (only defined for DTensor `param`): _tp_spec: DTensorSpec - all_gather_outputs: List[torch.Tensor] # 1D + all_gather_outputs: list[torch.Tensor] # 1D # All-gather extension attributes _extensions_data: ExtensionsData - _unsharded_inner_tensors: List[torch.Tensor] + _unsharded_inner_tensors: list[torch.Tensor] def __init__( self, @@ -241,7 +242,7 @@ class FSDPParam: if self.post_forward_mesh_info: self._init_sharded_post_forward_param_metadata(param) self._init_extensions() - self.all_gather_outputs: List[torch.Tensor] = [] + self.all_gather_outputs: list[torch.Tensor] = [] self.unsharded_accumulated_grad = None self._param_fqn: Optional[str] = None # prefixed from root module # TODO: Remove this padding logic once DTensor pads the local tensor: @@ -439,12 +440,12 @@ class FSDPParam: ) if has_fsdp_pre_all_gather: self._extensions_data = ExtensionsData() - self._unsharded_inner_tensors: List[torch.Tensor] = [] + self._unsharded_inner_tensors: list[torch.Tensor] = [] def init_all_gather_outputs( self, - all_gather_input_numels: List[int], - all_gather_input_dtypes: List[torch.dtype], + all_gather_input_numels: list[int], + all_gather_input_dtypes: list[torch.dtype], world_size: int, device: torch.device, force_recreate: bool = False, @@ -680,7 +681,7 @@ class FSDPParam: free_storage(tensor) @property - def all_gather_inputs(self) -> List[torch.Tensor]: # 1D + def all_gather_inputs(self) -> list[torch.Tensor]: # 1D self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) if self.sharded_state == ShardedState.SHARDED: if not compiled_autograd_enabled() and hasattr( diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 187074b9e8e..919107ca910 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import contextlib import logging -from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Set +from typing import Any, Callable, cast, NamedTuple, Optional import torch import torch.distributed as dist @@ -31,7 +31,7 @@ from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState logger = logging.getLogger("torch.distributed.fsdp.fully_shard") -_ModuleToHandleDict = Dict[nn.Module, RemovableHandle] # for state dict +_ModuleToHandleDict = dict[nn.Module, RemovableHandle] # for state dict """ @@ -77,7 +77,7 @@ class FSDPCommContext: self.all_gather_state: Optional[AllGatherState] = None self.reduce_scatter_state: Optional[ReduceScatterState] = None # Post-forward order for explicit backward prefetching - self.post_forward_order: List[FSDPParamGroup] = [] # will cause ref cycles + self.post_forward_order: list[FSDPParamGroup] = [] # will cause ref cycles def get_all_gather_streams( self, async_op: bool, training_state: TrainingState @@ -116,7 +116,7 @@ class FSDPParamGroup: def __init__( self, - params: List[nn.Parameter], + params: list[nn.Parameter], modules: tuple[nn.Module, ...], mesh_info: FSDPMeshInfo, post_forward_mesh_info: Optional[FSDPMeshInfo], @@ 
-162,7 +162,7 @@ class FSDPParamGroup: # - Communication and communication/computation overlap self.comm_ctx = FSDPCommContext() # Group's indices in the shared post-forward order - self._post_forward_indices: List[int] = [] + self._post_forward_indices: list[int] = [] # Whether to reduce gradients at all (whether for FSDP or HSDP) self.reduce_grads: bool = True # Whether to all-reduce gradients for HSDP; only used if @@ -325,8 +325,8 @@ class FSDPParamGroup: self._to_sharded() def pre_forward( - self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any] - ) -> tuple[tuple[Any, ...], Dict[str, Any]]: + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: if not compiled_autograd_enabled(): logger.debug("%s", self._with_fqn("FSDP::pre_forward")) with record_function(self._with_fqn("FSDP::pre_forward")): @@ -389,8 +389,8 @@ class FSDPParamGroup: return # Save the autograd-computed gradients before resharding to only # access the unsharded parameters when their data is present - fsdp_params_with_grad: List[FSDPParam] = [] - unsharded_grads: List[torch.Tensor] = [] + fsdp_params_with_grad: list[FSDPParam] = [] + unsharded_grads: list[torch.Tensor] = [] for fsdp_param in self.fsdp_params: if not hasattr(fsdp_param, "_unsharded_param"): continue @@ -543,8 +543,8 @@ class FSDPParamGroup: # Hook Registration # def _register_post_backward_hook( - self, args: tuple[Any, ...], kwargs: Dict[str, Any] - ) -> tuple[tuple[Any, ...], Dict[str, Any]]: + self, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: # Traceable FSDP2 relies on `root_post_backward_callback` to call each # `FSDPParamGroup.post_backward` if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled(): @@ -554,8 +554,8 @@ class FSDPParamGroup: args_list, args_spec = tree_flatten(args) kwargs_list, kwargs_spec = tree_flatten(kwargs) args_kwargs_list = list(args_list) + list(kwargs_list) - inp_tensor_indices: List[int] = [] - inp_tensors: List[torch.Tensor] = [] + inp_tensor_indices: list[int] = [] + inp_tensors: list[torch.Tensor] = [] for i, obj in enumerate(args_kwargs_list): if torch.is_tensor(obj) and obj.requires_grad: inp_tensor_indices.append(i) @@ -579,7 +579,7 @@ class FSDPParamGroup: ), f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" if num_pre_save_hooks > 0: return # already registered - modules_with_fsdp_params: Set[nn.Module] = { + modules_with_fsdp_params: set[nn.Module] = { fsdp_param._module_info.module for fsdp_param in self.fsdp_params } @@ -670,8 +670,8 @@ class FSDPParamGroup: def _get_param_module_infos( - params: List[nn.Parameter], modules: tuple[nn.Module, ...] -) -> List[ParamModuleInfo]: + params: list[nn.Parameter], modules: tuple[nn.Module, ...] +) -> list[ParamModuleInfo]: """ Shared parameter: lin1.weight = lin2.weight Shared module: mlp.lin1 = mlp.lin2 @@ -679,7 +679,7 @@ def _get_param_module_infos( find shared modules' parameters and shared parameters within a module. 
""" params_set = set(params) - param_to_module_info: Dict[nn.Parameter, ParamModuleInfo] = {} + param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {} for module in modules: for _, submodule in module.named_modules(remove_duplicate=False): for param_name, param in _named_parameters_with_duplicates( diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py index e0785c6ac45..61ed6d8f153 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -2,7 +2,8 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING import torch import torch.nn as nn @@ -40,7 +41,7 @@ class FSDPStateContext: def __init__(self) -> None: # All FSDP states in the root state's module tree - self.all_states: List[FSDPState] = [] + self.all_states: list[FSDPState] = [] # Iteration's forward root runs the once-per-forward logic; this root # may not be the overall root set by lazy initialization in cases where # only a submodule runs forward (e.g. encoder-only for eval) @@ -73,9 +74,9 @@ class FSDPState(_State): self._state_ctx = FSDPStateContext() self._comm_ctx = FSDPCommContext() self._training_state: TrainingState = TrainingState.IDLE - self._states_to_forward_prefetch: List[FSDPState] = [] - self._states_to_backward_prefetch: List[FSDPState] = [] - self._modules_to_run_forward: Set[nn.Module] = set() + self._states_to_forward_prefetch: list[FSDPState] = [] + self._states_to_backward_prefetch: list[FSDPState] = [] + self._modules_to_run_forward: set[nn.Module] = set() # Define a separate init since `__init__` is called in the contract def init( @@ -108,8 +109,8 @@ class FSDPState(_State): self._post_forward_hook_handle = hook_handle def _root_pre_forward( - self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any] - ) -> tuple[tuple[Any, ...], Dict[str, Any]]: + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: self._lazy_init() if self._state_ctx.iter_forward_root is not None: return args, kwargs @@ -150,7 +151,7 @@ class FSDPState(_State): ) detect_compiled_autograd() root_module = self._modules[0] - visited_states: Set[FSDPState] = set() + visited_states: set[FSDPState] = set() for module_name, module in root_module.named_modules(): if (state := _get_module_fsdp_state(module)) is None: continue @@ -188,8 +189,8 @@ class FSDPState(_State): """Sets module and parameter FQN attributes for debugging.""" assert self._is_root root_module = self._modules[0] - param_to_fsdp_param: Dict[nn.Parameter, FSDPParam] = {} - module_to_fsdp_param_group: Dict[nn.Module, FSDPParamGroup] = {} + param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {} + module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {} for state in self._state_ctx.all_states: if fsdp_param_group := state._fsdp_param_group: for fsdp_param in fsdp_param_group.fsdp_params: @@ -211,8 +212,8 @@ class FSDPState(_State): @disable_if_config_true def _pre_forward( - self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any] - ) -> tuple[tuple[Any, ...], Dict[str, Any]]: + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: # When composing with module-hook-based activation checkpointing, the # 
the pre-backward hook is responsible for the unshard if self._training_state == TrainingState.PRE_BACKWARD: @@ -343,7 +344,7 @@ def _register_group_forward_hooks( modules: Sequence[nn.Module], pre_hook: Callable, post_hook: Callable, - modules_to_run: Set[nn.Module], + modules_to_run: set[nn.Module], ): """ Registers group forward pre and post-hooks. The pre-hook runs upon the diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index 1a51ae06847..02b48bd01f1 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -1,18 +1,8 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import functools -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - List, - NoReturn, - Optional, - Type, - Union, -) +from collections.abc import Iterable +from typing import Any, Callable, cast, NoReturn, Optional, Union import torch import torch.nn as nn @@ -42,14 +32,14 @@ __all__ = [ ] -cls_to_fsdp_cls: Dict[Type, Type] = {} +cls_to_fsdp_cls: dict[type, type] = {} # The decorator adds a state object to `module` that can be accessed via # `fully_shard.state(module)`. The state object and module are 1:1. @contract(state_cls=FSDPState) def fully_shard( - module: Union[nn.Module, List[nn.Module]], + module: Union[nn.Module, list[nn.Module]], *, mesh: Optional[DeviceMesh] = None, reshard_after_forward: Union[bool, int] = True, @@ -331,7 +321,7 @@ class FSDPModule: if fsdp_param_group := state._fsdp_param_group: fsdp_param_group.reshard_after_backward = reshard_after_backward - def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None: + def set_modules_to_forward_prefetch(self, modules: list["FSDPModule"]) -> None: """ Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in forward. The prefetching runs after this @@ -351,7 +341,7 @@ class FSDPModule: module._get_fsdp_state() for module in modules ] - def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None: + def set_modules_to_backward_prefetch(self, modules: list["FSDPModule"]) -> None: """ Sets the FSDP modules for which this FSDP module should explicitly prefetch all-gathers in backward. This overrides the default backward diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 6e1ec4cbf26..cab8c4ec835 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -3,21 +3,8 @@ import collections import itertools import os import warnings -from typing import ( - Any, - Callable, - Deque, - Dict, - Generator, - Iterable, - Iterator, - List, - no_type_check, - Optional, - Set, - TYPE_CHECKING, - Union, -) +from collections.abc import Generator, Iterable, Iterator +from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -329,7 +316,7 @@ def _init_ignored_module_states( def _check_ignored_states( - ignored_states: List[Any], passed_as_ignored_states: bool + ignored_states: list[Any], passed_as_ignored_states: bool ) -> None: """ Check that the ignored states are uniformly parameters or uniformly modules. 
@@ -361,7 +348,7 @@ def _check_ignored_states( def _init_device_handle( state: _FSDPState, module: nn.Module, - ignored_params: Set[nn.Parameter], + ignored_params: set[nn.Parameter], device_id: Optional[Union[int, torch.device]], ) -> _FSDPState: """ @@ -416,7 +403,7 @@ def _init_buffer_state( # `module`) to its original dtype for restoring that dtype during model # checkpointing when buffer mixed precision is enabled. The names should # be clean since the casting happens in a `summon_full_params()` context. - _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {} + _buffer_name_to_orig_dtype: dict[str, torch.dtype] = {} for buffer_name, buffer in module.named_buffers(): buffer_name = clean_tensor_name(buffer_name) _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype @@ -479,13 +466,13 @@ def _init_core_state( state._unshard_event = None # Mapping from fully sharded module to the handles it is responsible to # unshard and reshard (see [Note: Fully Sharded Module]) - _fully_sharded_module_to_handle: Dict[nn.Module, FlatParamHandle] = {} + _fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {} state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle # Invariant: `state.params` contains exactly the `FlatParameter`s of the # handles in `state._handle` _handle: Optional[FlatParamHandle] = None state._handle = _handle - params: List[FlatParameter] = [] + params: list[FlatParameter] = [] state.params = params return state @@ -494,11 +481,11 @@ def _init_core_state( def _init_runtime_state( state: _FSDPState, ) -> _FSDPState: - _root_pre_forward_handles: List[RemovableHandle] = [] + _root_pre_forward_handles: list[RemovableHandle] = [] state._root_pre_forward_handles = _root_pre_forward_handles - _pre_forward_handles: List[RemovableHandle] = [] + _pre_forward_handles: list[RemovableHandle] = [] state._pre_forward_handles = _pre_forward_handles - _post_forward_handles: List[RemovableHandle] = [] + _post_forward_handles: list[RemovableHandle] = [] state._post_forward_handles = _post_forward_handles state._sync_gradients = True state._comm_hook = None @@ -542,13 +529,13 @@ def _init_state_dict_state(state: _FSDPState) -> _FSDPState: state_dict_config: StateDictConfig = FullStateDictConfig() state._optim_state_dict_config = FullOptimStateDictConfig() state._state_dict_config = state_dict_config - unshard_params_ctx: Dict[nn.Module, Generator] = {} + unshard_params_ctx: dict[nn.Module, Generator] = {} state._unshard_params_ctx = unshard_params_ctx return state -def _verify_managed_params(module: nn.Module, params: List[nn.Parameter]) -> None: +def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None: """ Verify if the parameters are accepted by FSDP. The only restriction now is that the parameter cannot be a scalar tensor (param.shape == []). @@ -639,7 +626,7 @@ def _init_param_handle_from_module( @no_type_check def _init_param_handle_from_params( state: _FSDPState, - params: List[nn.Parameter], + params: list[nn.Parameter], fully_sharded_module: nn.Module, ): if len(params) == 0: @@ -670,7 +657,7 @@ def _init_param_handle_from_params( def _get_ignored_modules( root_module: nn.Module, _ignored_modules: Optional[Iterable[torch.nn.Module]], -) -> Set[nn.Module]: +) -> set[nn.Module]: """ Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances. 
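Only the generic containers move, though: names such as `Any`, `Callable`, `Optional`, `Union`, `cast`, `no_type_check`, and `TYPE_CHECKING` are still imported from `typing` throughout the patch (`Optional[X]` would only become `X | None` under PEP 604, which requires Python 3.10). A hedged sketch of a typical post-migration import mix, using an invented helper rather than the patch's `_apply_to_modules`:

from collections.abc import Iterable
from typing import Any, Callable, Optional


def apply_to_named_modules(
    names: Iterable[str],
    module_fn: Callable[[str], Any],
    filter_fqns: Optional[list[str]] = None,
) -> dict[str, Any]:
    # Builtin generics (list, dict) mix freely with the names kept in typing.
    results: dict[str, Any] = {}
    for name in names:
        if filter_fqns is not None and name not in filter_fqns:
            continue
        results[name] = module_fn(name)
    return results


if __name__ == "__main__":
    print(apply_to_named_modules(["lin1", "lin2"], str.upper, filter_fqns=["lin2"]))
    # -> {'lin2': 'LIN2'}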
@@ -726,15 +713,15 @@ def _get_ignored_modules( def _get_ignored_params( root_module: torch.nn.Module, - ignored_modules: Set[torch.nn.Module], + ignored_modules: set[torch.nn.Module], ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None, -) -> Set[torch.nn.Parameter]: +) -> set[torch.nn.Parameter]: """ Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``. :class:`FlatParameter` s are excluded from the result. """ - all_ignored_params: Set[torch.nn.Parameter] = set() + all_ignored_params: set[torch.nn.Parameter] = set() params_in_ignored_modules = { p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p) @@ -760,10 +747,10 @@ def _get_ignored_params( def _get_ignored_buffer_names( root_module: torch.nn.Module, - ignored_modules: Set[torch.nn.Module], -) -> Set[str]: + ignored_modules: set[torch.nn.Module], +) -> set[str]: """Return the cleaned buffer FQNs in ``ignored_modules``.""" - all_ignored_buffer_names: Set[str] = set() + all_ignored_buffer_names: set[str] = set() buffers_in_ignored_modules = { buffer for m in ignored_modules for buffer in m.buffers() @@ -787,7 +774,7 @@ def _get_ignored_buffer_names( return all_ignored_buffer_names -def _get_buffer_names(root_module: nn.Module) -> Set[str]: +def _get_buffer_names(root_module: nn.Module) -> set[str]: """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`.""" return { clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers() @@ -796,7 +783,7 @@ def _get_buffer_names(root_module: nn.Module) -> Set[str]: def _check_single_device_module( module: nn.Module, - ignored_params: Set[nn.Parameter], + ignored_params: set[nn.Parameter], device_id: Optional[Union[int, torch.device]], ) -> None: """ @@ -855,8 +842,8 @@ def _get_device_from_device_id( def _need_to_materialize_module( module: nn.Module, - ignored_params: Set[nn.Parameter], - ignored_modules: Set[nn.Module], + ignored_params: set[nn.Parameter], + ignored_modules: set[nn.Module], ) -> tuple[bool, bool]: """ Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization. @@ -886,7 +873,7 @@ def _need_to_materialize_module( def _materialize_with_param_init_fn( root_module: nn.Module, param_init_fn: Callable[[nn.Module], None], - ignored_modules: Set[nn.Module], + ignored_modules: set[nn.Module], ) -> None: if not callable(param_init_fn): raise ValueError( @@ -900,7 +887,7 @@ def _materialize_with_param_init_fn( def _materialize_meta_module( root_module: nn.Module, device_from_device_id: Optional[torch.device], - ignored_modules: Set[nn.Module], + ignored_modules: set[nn.Module], device_handle: _FSDPDeviceHandle, ): # Run default meta device initialization @@ -933,13 +920,13 @@ def _materialize_meta_module( def _get_modules_to_materialize( - root_module: nn.Module, ignored_modules: Set[nn.Module] -) -> List[nn.Module]: + root_module: nn.Module, ignored_modules: set[nn.Module] +) -> list[nn.Module]: # Run BFS to collect the modules to materialize via `reset_parameters()`, # stopping at any module with FSDP already applied or at ignored modules. 
- modules_to_materialize: List[nn.Module] = [] + modules_to_materialize: list[nn.Module] = [] queue = collections.deque([root_module]) - visited_modules: Set[nn.Module] = {root_module} + visited_modules: set[nn.Module] = {root_module} while queue: module = queue.popleft() modules_to_materialize.append(module) @@ -956,8 +943,8 @@ def _get_modules_to_materialize( def _move_module_to_device( module: nn.Module, - ignored_params: Set[nn.Parameter], - ignored_buffers: Set[torch.Tensor], + ignored_params: set[nn.Parameter], + ignored_buffers: set[torch.Tensor], device_from_device_id: Optional[torch.device], ) -> None: """ @@ -976,10 +963,10 @@ def _move_module_to_device( if device_from_device_id is not None: # BFS from `module` without traversing any nested FSDP instances to # collect the parameters/buffers that have not yet been managed - queue: Deque[nn.Module] = collections.deque() + queue: collections.deque[nn.Module] = collections.deque() queue.append(module) - params: List[nn.Parameter] = [] - buffers: List[torch.Tensor] = [] + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] while queue: curr_module = queue.popleft() # NOTE: We include a check to only move parameters/buffers that are @@ -1009,8 +996,8 @@ def _move_module_to_device( def _move_states_to_device( - params: List[nn.Parameter], - buffers: List[torch.Tensor], + params: list[nn.Parameter], + buffers: list[torch.Tensor], device_from_device_id: Optional[torch.device], ) -> None: """ @@ -1053,7 +1040,7 @@ def _warn_cpu_init(): def _get_compute_device( module: nn.Module, - ignored_params: Set[nn.Parameter], + ignored_params: set[nn.Parameter], device_from_device_id: Optional[torch.device], rank: int, device_handle: _FSDPDeviceHandle, @@ -1088,7 +1075,7 @@ def _get_compute_device( # TODO: See how to deprecate! def _sync_module_params_and_buffers( module: nn.Module, - params: List[nn.Parameter], + params: list[nn.Parameter], process_group: dist.ProcessGroup, ) -> None: """ @@ -1097,7 +1084,7 @@ def _sync_module_params_and_buffers( Precondition: ``sync_module_states == True`` and ``self.process_group`` has been set. """ - module_states: List[torch.Tensor] = [] + module_states: list[torch.Tensor] = [] for buffer in module.buffers(): # Avoid re-synchronizing buffers in case of nested wrapping if not getattr(buffer, FSDP_SYNCED, False): @@ -1131,7 +1118,7 @@ def _sync_module_params_and_buffers( def _check_module_states_for_sync_module_states( - module_states: List[torch.Tensor], + module_states: list[torch.Tensor], ) -> None: if module_states and any( tensor.device == torch.device("cpu") for tensor in module_states @@ -1145,7 +1132,7 @@ def _check_module_states_for_sync_module_states( def _get_orig_params( module: nn.Module, - ignored_params: Set[nn.Parameter], + ignored_params: set[nn.Parameter], ) -> Iterator[nn.Parameter]: """ Return an iterator over the original parameters in ``module``. @@ -1167,7 +1154,7 @@ def _get_orig_params( def _check_orig_params_flattened( fsdp_module, - ignored_params: Set[nn.Parameter], + ignored_params: set[nn.Parameter], ) -> None: """ Check that original parameters in ``fsdp_module`` have been flattened. 
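The next file, `_limiter_utils.py`, shows the same idea applied to a concrete container: `typing.Deque[torch.Event]` becomes `collections.deque[torch.Event]`, since on Python 3.9+ the standard-library containers themselves accept subscripts. A self-contained sketch of that annotation style (strings stand in for `torch.Event` to avoid the torch dependency):

import collections
from typing import Optional


class FreeEventQueue:
    """Toy bounded queue mirroring the annotation style used in the patch."""

    def __init__(self, max_inflight: int = 2) -> None:
        # `collections.deque` is subscriptable directly; no `typing.Deque` needed.
        self._queue: collections.deque[str] = collections.deque()
        self._max_inflight = max_inflight

    def enqueue(self, event: str) -> None:
        self._queue.append(event)

    def dequeue_if_needed(self) -> Optional[str]:
        # Pop the oldest event once the in-flight limit is reached.
        if len(self._queue) >= self._max_inflight:
            return self._queue.popleft()
        return None


if __name__ == "__main__":
    q = FreeEventQueue()
    q.enqueue("ag0")
    assert q.dequeue_if_needed() is None   # only one in flight
    q.enqueue("ag1")
    assert q.dequeue_if_needed() == "ag0"  # limit of 2 reached, oldest returned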
diff --git a/torch/distributed/fsdp/_limiter_utils.py b/torch/distributed/fsdp/_limiter_utils.py index 5cc56b29f84..f9b19058534 100644 --- a/torch/distributed/fsdp/_limiter_utils.py +++ b/torch/distributed/fsdp/_limiter_utils.py @@ -1,5 +1,5 @@ import collections -from typing import Deque, Optional +from typing import Optional import torch @@ -12,7 +12,7 @@ class _FreeEventQueue: """ def __init__(self) -> None: - self._queue: Deque[torch.Event] = collections.deque() + self._queue: collections.deque[torch.Event] = collections.deque() self._max_num_inflight_all_gathers = 2 # empirically chosen def enqueue(self, free_event: torch.Event) -> None: diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 919f7f3a645..c6b9c3d1141 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -3,23 +3,10 @@ import copy import functools import logging import warnings +from collections.abc import Iterable, Iterator, Sequence from contextlib import ExitStack from dataclasses import dataclass, field -from typing import ( - Any, - cast, - Dict, - Iterable, - Iterator, - List, - NamedTuple, - no_type_check, - Optional, - Sequence, - Set, - TYPE_CHECKING, - Union, -) +from typing import Any, cast, NamedTuple, no_type_check, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -66,11 +53,11 @@ logger = logging.getLogger(__name__) class FSDPParamInfo: state: _FSDPState handle: FlatParamHandle - param_indices: Dict[str, int] - param_requires_grad: List[bool] + param_indices: dict[str, int] + param_requires_grad: list[bool] -def sorted_items(dictionary: Dict[str, Any]) -> Iterator[tuple[str, Any]]: +def sorted_items(dictionary: dict[str, Any]) -> Iterator[tuple[str, Any]]: keys = sorted(dictionary.keys()) for k in keys: yield k, dictionary[k] @@ -97,9 +84,9 @@ class _ConsolidatedOptimState: name to its value. """ - tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict) - zero_dim_tensor_state: Dict[str, torch.Tensor] = field(default_factory=dict) - non_tensor_state: Dict[str, Any] = field(default_factory=dict) + tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + zero_dim_tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + non_tensor_state: dict[str, Any] = field(default_factory=dict) class _PosDimTensorInfo(NamedTuple): @@ -131,11 +118,11 @@ class _OptimStateKey(NamedTuple): def _unflatten_optim_state( fsdp_param_info: FSDPParamInfo, - flat_param_state: Dict[str, Any], + flat_param_state: dict[str, Any], to_save: bool, shard_state: bool, cpu_offload: bool, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Unflattens the optimizer state, consisting of the "state" part and the "param_groups" part. Unflattening the "state" part involves consolidating @@ -190,7 +177,7 @@ def _is_zero_dim_tensor(x: Any) -> bool: def _communicate_optim_state( fsdp_param_info: FSDPParamInfo, - flat_param_state: Dict[str, Any], + flat_param_state: dict[str, Any], ) -> _ConsolidatedOptimState: """ Communicates the optimizer state for a flat parameter across ranks. 
All @@ -262,7 +249,7 @@ def _unflatten_communicated_optim_state( fsdp_param_info: FSDPParamInfo, state: _ConsolidatedOptimState, shard_state: bool, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """ Unflattens the communicated optimizer state (given by ``tensor_state``, ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat @@ -283,8 +270,8 @@ def _unflatten_communicated_optim_state( fsdp_state = fsdp_param_info.state handle = fsdp_param_info.handle flat_param = handle.flat_param - unflat_param_state: List[Dict[str, Any]] = [] - flat_param_views: Dict[str, Iterator] = {} + unflat_param_state: list[dict[str, Any]] = [] + flat_param_views: dict[str, Iterator] = {} num_unflat_params = flat_param._num_params tensor_state, zero_dim_tensor_state, non_tensor_state = ( state.tensor_state, @@ -337,10 +324,10 @@ def _unflatten_communicated_optim_state( def _broadcast_processed_state( fsdp_state: _FSDPState, - optim_state: Dict[str, Any], + optim_state: dict[str, Any], group: Optional[dist.ProcessGroup], -) -> Dict[str, Any]: - objects: List[Any] = [None] +) -> dict[str, Any]: + objects: list[Any] = [None] if dist.get_rank(group) == 0: objects[0] = tree_map_only( torch.Tensor, @@ -380,8 +367,8 @@ def _broadcast_state( def _shard_orig_param_state( fsdp_param_info: FSDPParamInfo, fqn: str, - optim_state: Dict[str, Any], -) -> Dict[str, Any]: + optim_state: dict[str, Any], +) -> dict[str, Any]: """ Shard the optimizer state for the original parameter with the name ``fqn``. This API should only be used when ``use_orig_params`` is True. @@ -398,7 +385,7 @@ def _shard_orig_param_state( if not shard_param_info.in_shard: return {} # Flatten and shard the state. - new_optim_state: Dict[str, Any] = {} + new_optim_state: dict[str, Any] = {} intra_param_start_idx = shard_param_info.intra_param_start_idx intra_param_end_idx = shard_param_info.intra_param_end_idx for state_name, value in optim_state.items(): @@ -413,13 +400,13 @@ def _shard_orig_param_state( def _flatten_optim_state_dict( - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], model: nn.Module, use_orig_params: bool = False, optim: Optional[torch.optim.Optimizer] = None, rank0_only: bool = False, group: Optional[dist.ProcessGroup] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Flattens the full optimizer state dict, still keying by unflattened parameter names. 
@@ -461,7 +448,7 @@ def _flatten_optim_state_dict( unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) # Construct the "state" part - flat_osd_state: Dict[Union[_OptimStateKey, str], Any] = {} + flat_osd_state: dict[Union[_OptimStateKey, str], Any] = {} unflat_osd_state = unflat_osd["state"] all_state_keys = set(unflat_osd_state.keys()) @@ -557,9 +544,9 @@ def _flatten_optim_state_dict( def _flatten_optim_state( fsdp_param_info: FSDPParamInfo, - unflat_osd_state: Dict[str, Dict[str, Any]], - unflat_param_names: List[str], -) -> Dict[str, Any]: + unflat_osd_state: dict[str, dict[str, Any]], + unflat_param_names: list[str], +) -> dict[str, Any]: """ Flattens the optimizer state in ``full_optim_state_dict`` for a single flat parameter in ``fsdp_param_info`` corresponding to the unflattened @@ -629,7 +616,7 @@ def _flatten_optim_state( assert state_names is not None # Flatten the state - flat_state: Dict[str, Optional[torch.Tensor]] = {} + flat_state: dict[str, Optional[torch.Tensor]] = {} for state_name in state_names: state_values = [ unflat_param_state[state_name] if unflat_param_state is not None else None @@ -695,8 +682,8 @@ def _flatten_optim_state( def _flatten_tensor_optim_state( state_name: str, - pos_dim_tensors: List[torch.Tensor], - unflat_param_names: List[str], + pos_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], unflat_param_shapes: Sequence[torch.Size], handle: FlatParamHandle, ) -> torch.Tensor: @@ -780,8 +767,8 @@ def _flatten_tensor_optim_state( def _flatten_zero_dim_tensor_optim_state( state_name: str, - zero_dim_tensors: List[torch.Tensor], - unflat_param_names: List[str], + zero_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], ) -> torch.Tensor: """ Flattens the zero-dimension tensor optimizer state given by the values @@ -834,8 +821,8 @@ def _flatten_zero_dim_tensor_optim_state( def _flatten_non_tensor_optim_state( state_name: str, - non_tensors: List[Any], - unflat_param_names: List[str], + non_tensors: list[Any], + unflat_param_names: list[str], ) -> Any: """ Flattens the non-tensor optimizer state given by the values ``non_tensors`` @@ -872,18 +859,18 @@ def _flatten_non_tensor_optim_state( def _rekey_sharded_optim_state_dict( - sharded_osd: Dict[str, Any], + sharded_osd: dict[str, Any], model: nn.Module, optim: torch.optim.Optimizer, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[nn.Parameter], ] ], using_optim_input: bool, is_named_optimizer: bool = False, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Rekeys the optimizer state dict from unflattened parameter names to flat parameter IDs according to the calling rank's ``optim``, which may be @@ -892,8 +879,8 @@ def _rekey_sharded_optim_state_dict( """ param_to_fqns = _get_param_to_fqns(model) flat_param_to_fqn = _get_flat_param_to_fqn(model) - param_to_param_key: Dict[nn.Parameter, Union[int, str]] = cast( - Dict[nn.Parameter, Union[int, str]], + param_to_param_key: dict[nn.Parameter, Union[int, str]] = cast( + dict[nn.Parameter, Union[int, str]], ( _get_param_to_param_id_from_optim_input(model, optim_input) if using_optim_input @@ -907,10 +894,10 @@ def _rekey_sharded_optim_state_dict( # passed to the optimizer assert len(param_to_param_key) <= len(param_to_fqns) - unflat_param_names_to_flat_param_key: Dict[ + unflat_param_names_to_flat_param_key: dict[ tuple[str, ...], Union[int, str] ] = {} # for "state" - unflat_param_name_to_flat_param_key: Dict[ + unflat_param_name_to_flat_param_key: dict[ str, Union[int, 
str] ] = {} # for "param_groups" for param, unflat_param_names in param_to_fqns.items(): @@ -923,7 +910,7 @@ def _rekey_sharded_optim_state_dict( unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key sharded_osd_state = sharded_osd["state"] - rekeyed_osd_state: Dict[Union[str, int], Any] = {} + rekeyed_osd_state: dict[Union[str, int], Any] = {} for key, param_state in sharded_osd_state.items(): if isinstance(key, str): rekeyed_osd_state[key] = param_state @@ -935,7 +922,7 @@ def _rekey_sharded_optim_state_dict( # Only process param_groups if it exists in sharded_osd if "param_groups" in sharded_osd: - rekeyed_osd_param_groups: List[Dict[str, Any]] = [] + rekeyed_osd_param_groups: list[dict[str, Any]] = [] for unflat_param_group in sharded_osd["param_groups"]: flat_param_group = copy.deepcopy(unflat_param_group) flat_param_keys = sorted( @@ -955,11 +942,11 @@ def _get_param_id_to_param_from_optim_input( model: nn.Module, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[nn.Parameter], ] ] = None, -) -> Dict[int, nn.Parameter]: +) -> dict[int, nn.Parameter]: """ Constructs a mapping from parameter IDs to parameters. This may be used both for models with ``FlatParameter`` s and without. @@ -993,7 +980,7 @@ def _get_param_id_to_param_from_optim_input( if optim_input is None: return dict(enumerate(model.parameters())) try: - params = cast(List[nn.Parameter], list(optim_input)) + params = cast(list[nn.Parameter], list(optim_input)) except TypeError as e: raise TypeError( "Optimizer input should be an iterable of Tensors or dicts, " @@ -1013,7 +1000,7 @@ def _get_param_id_to_param_from_optim_input( if all_tensors: return dict(enumerate(params)) assert all_dicts - param_id_to_param: List[nn.Parameter] = [] + param_id_to_param: list[nn.Parameter] = [] for param_group in params: has_params_key = "params" in param_group # type: ignore[operator] assert has_params_key, ( @@ -1026,7 +1013,7 @@ def _get_param_id_to_param_from_optim_input( return dict(enumerate(param_id_to_param)) -def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]: +def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]: """ Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical" @@ -1053,7 +1040,7 @@ def _get_flat_param_to_fqn(model: torch.nn.Module) -> Dict[FlatParameter, str]: def return_fn(flat_param_to_fqn): return flat_param_to_fqn - flat_param_to_fqn_ret: Dict[FlatParameter, str] = {} + flat_param_to_fqn_ret: dict[FlatParameter, str] = {} return _apply_to_modules( model, module_fn, @@ -1067,16 +1054,16 @@ def _get_param_key_to_param( optim: torch.optim.Optimizer, model: Optional[nn.Module] = None, is_named_optimizer: bool = False, - param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None, - flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None, -) -> Dict[Union[int, str], nn.Parameter]: + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[Union[int, str], nn.Parameter]: """ Constructs a mapping from parameter keys to parameters. For the regular optimizers, the keys are parameter IDs. For NamedOptimizer, the keys are FQNs. This API may be used both for models with ``FlatParameter`` s and without. 
""" - clean_fqn_to_curr_fqn: Dict[str, str] = {} + clean_fqn_to_curr_fqn: dict[str, str] = {} if is_named_optimizer: assert ( param_to_fqns is not None and flat_param_to_fqn is not None @@ -1085,7 +1072,7 @@ def _get_param_key_to_param( for key, _ in _named_parameters_with_duplicates(model): clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key - param_key_to_param: Dict[Union[str, int], nn.Parameter] = {} + param_key_to_param: dict[Union[str, int], nn.Parameter] = {} pid = 0 for param_group in optim.param_groups: if is_named_optimizer: @@ -1118,9 +1105,9 @@ def _get_param_to_param_key( optim: torch.optim.Optimizer, model: Optional[nn.Module] = None, is_named_optimizer: bool = False, - param_to_fqns: Optional[Dict[nn.Parameter, List[str]]] = None, - flat_param_to_fqn: Optional[Dict[FlatParameter, str]] = None, -) -> Dict[nn.Parameter, Union[int, str]]: + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[nn.Parameter, Union[int, str]]: """ Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API only supports the case where `optim` is a regular optimizer, not NamedOptimizer. @@ -1136,25 +1123,25 @@ def _get_param_to_param_id_from_optim_input( model: nn.Module, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[nn.Parameter], ] ] = None, -) -> Dict[nn.Parameter, int]: +) -> dict[nn.Parameter, int]: """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`.""" param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) return {param: param_id for param_id, param in param_id_to_param.items()} def _check_missing_keys_on_rank( - r0_optim_state_keys: List[_OptimStateKey], - optim_state_key_to_param_key: Dict[_OptimStateKey, Union[str, int]], - param_key_to_param: Dict[Union[str, int], nn.Parameter], + r0_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[str, int]], + param_key_to_param: dict[Union[str, int], nn.Parameter], group: Optional[dist.ProcessGroup], ) -> None: # Ensure that all ranks have at least the optimizer states needed by # rank 0's optimizer - missing_keys: List[_OptimStateKey] = [] + missing_keys: list[_OptimStateKey] = [] for r0_optim_state_key in r0_optim_state_keys: if r0_optim_state_key not in optim_state_key_to_param_key: # A parameter from rank 0's optimizer does not exist for this @@ -1179,7 +1166,7 @@ def _check_missing_keys_on_rank( "are missing some of those states" ) for rank, keys in enumerate(obj_list): - keys = cast(List[_OptimStateKey], keys) + keys = cast(list[_OptimStateKey], keys) if len(keys) > 0: error_msg += ( f"\nRank {rank} is missing states for the parameters: " @@ -1189,13 +1176,13 @@ def _check_missing_keys_on_rank( def _map_param_key_to_optim_keys( - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], group: Optional[dist.ProcessGroup], - param_key_to_param: Dict[Union[int, str], nn.Parameter], - param_to_fqns: Dict[nn.Parameter, List[str]], - fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], merge_keys: bool = False, -) -> tuple[List[_OptimStateKey], Dict[_OptimStateKey, Union[int, str]]]: +) -> tuple[list[_OptimStateKey], dict[_OptimStateKey, Union[int, str]]]: """ Construct the local mapping between the ``_OptimStateKey`` and parameter keys and 
all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0 @@ -1203,8 +1190,8 @@ def _map_param_key_to_optim_keys( Note that ``merge_keys`` should equal to ``use_orig_params``. """ rank = dist.get_rank(group) - optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]] = {} # local - all_optim_state_keys: List[_OptimStateKey] = [] + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]] = {} # local + all_optim_state_keys: list[_OptimStateKey] = [] for param_key, param in param_key_to_param.items(): # Do not include parameters without state to avoid empty mappings @@ -1228,7 +1215,7 @@ def _map_param_key_to_optim_keys( optim_state_key_to_param_key[optim_state_key] = param_key if merge_keys: - all_keys: List[List[_OptimStateKey]] = [ + all_keys: list[list[_OptimStateKey]] = [ [] for _ in range(dist.get_world_size(group)) ] dist.all_gather_object(all_keys, all_optim_state_keys, group=group) @@ -1237,7 +1224,7 @@ def _map_param_key_to_optim_keys( ] all_optim_state_keys = sorted(set(merge_all_optim_state_keys)) else: - key_obj_list: List[Optional[List[_OptimStateKey]]] = ( + key_obj_list: list[Optional[list[_OptimStateKey]]] = ( [all_optim_state_keys] if rank == 0 else [None] ) dist.broadcast_object_list(key_obj_list, src=0, group=group) @@ -1254,11 +1241,11 @@ def _map_param_key_to_optim_keys( def _unflatten_param_groups( - state_dict: Dict[str, Any], - param_key_to_param: Dict[Union[int, str], nn.Parameter], - param_to_fqns: Dict[nn.Parameter, List[str]], -) -> List[Dict[str, Any]]: - param_groups: List[Dict[str, Any]] = [] + state_dict: dict[str, Any], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], +) -> list[dict[str, Any]]: + param_groups: list[dict[str, Any]] = [] for flat_param_group in state_dict["param_groups"]: unflat_param_group = copy.deepcopy(flat_param_group) param_group_params = [ @@ -1277,7 +1264,7 @@ def _unflatten_param_groups( return param_groups -def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool: +def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool: """ Returns whether the state_dict is from a NamedOptimizer. This function checks that the keys in the state_dict['state'] are strings @@ -1299,22 +1286,22 @@ def _is_named_optimizer(optim_state_dict: Dict[str, Any]) -> bool: @dataclass class StateInfo: # The key of these dictionaries are the state name, e.g., `exp_avg`. - tensors: Dict[str, _PosDimTensorInfo] - scalar_tensors: Dict[str, torch.Tensor] - non_tensors: Dict[str, Any] + tensors: dict[str, _PosDimTensorInfo] + scalar_tensors: dict[str, torch.Tensor] + non_tensors: dict[str, Any] def _allgather_state_info( fsdp_state: _FSDPState, - input_states: Dict[str, Any], -) -> List[Dict[str, StateInfo]]: + input_states: dict[str, Any], +) -> list[dict[str, StateInfo]]: """ Given the ``input_states``, allgather StateInfo for each state. The function uses all_gather_object to gather StateInfo so no GPU tensors are sent. 
""" - processed_state_dict: Dict[str, StateInfo] = {} - gathered_state_info: List[Dict[str, StateInfo]] = [ + processed_state_dict: dict[str, StateInfo] = {} + gathered_state_info: list[dict[str, StateInfo]] = [ {} for _ in range(fsdp_state.world_size) ] @@ -1343,10 +1330,10 @@ def _allgather_state_info( def _convert_all_state_info( fsdp_param_info: FSDPParamInfo, - gathered_state_info: List[Dict[str, StateInfo]], - input_states: Dict[str, Any], - output_states: Dict[str, Dict[str, Any]], -) -> tuple[Optional[torch.dtype], Dict[str, List[Optional[torch.Tensor]]]]: + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + output_states: dict[str, dict[str, Any]], +) -> tuple[Optional[torch.dtype], dict[str, list[Optional[torch.Tensor]]]]: """ Given the ``gathered_state_info`` and ``input_states``, the API converted the StateInfo into the original state if the state is not a non-scalar @@ -1354,20 +1341,20 @@ def _convert_all_state_info( ``state_buffer`` in a correct order for later allgather purpose. """ - state_buffers: Dict[str, List[Optional[torch.Tensor]]] = {} + state_buffers: dict[str, list[Optional[torch.Tensor]]] = {} for fqn, gathered_state in output_states.items(): state_info = [s[fqn] for s in gathered_state_info] all_tensor_states = sorted( {n for state in state_info for n in state.tensors.keys()} ) - empty_ranks: Set[int] = set() + empty_ranks: set[int] = set() dtype: Optional[torch.dtype] = None # First check all the non-scalar states and get the information of # states on each rank. for state_name in all_tensor_states: numels = [] - _empty_ranks: Set[int] = set() + _empty_ranks: set[int] = set() for rank, object_state in enumerate(state_info): numels.append(0) info = object_state.tensors.get(state_name, None) @@ -1426,7 +1413,7 @@ def _convert_all_state_info( def _unflatten_orig_param_states( fsdp_param_info: FSDPParamInfo, - output_states: Dict[str, Dict[str, Any]], + output_states: dict[str, dict[str, Any]], state_name: str, shard_state: bool, to_save: bool, @@ -1498,12 +1485,12 @@ def _unflatten_orig_param_states( def _allgather_orig_param_states( fsdp_param_info: FSDPParamInfo, - gathered_state_info: List[Dict[str, StateInfo]], - input_states: Dict[str, Any], + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], shard_state: bool, to_save: bool, cpu_offload: bool, -) -> Dict[str, Dict[str, Any]]: +) -> dict[str, dict[str, Any]]: """ Given the ``gathered_state_info`` and ``input_states``, the API allgathers all tensor states and restore non-tensor states from ``gathered_state_info``. @@ -1515,7 +1502,7 @@ def _allgather_orig_param_states( fsdp_state._device_handle.memory_summary(), ) - output_states: Dict[str, Dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} dtype, state_buffers = _convert_all_state_info( fsdp_param_info, gathered_state_info, input_states, output_states @@ -1524,7 +1511,7 @@ def _allgather_orig_param_states( if len(state_buffers) == 0: return output_states - has_state_params: List[bool] = [ + has_state_params: list[bool] = [ True if fqn in output_states else False for fqn, idx in fsdp_param_info.param_indices.items() ] @@ -1545,7 +1532,7 @@ def _allgather_orig_param_states( # Synchronize can be slow but this will be easier for us to debug. 
fsdp_state._device_handle.synchronize() for state_name, buffers in state_buffers.items(): - local_buffers: List[torch.Tensor] = [] + local_buffers: list[torch.Tensor] = [] begin = fsdp_state.rank * flat_param._sharded_size.numel() # End is inclusive. end = begin + flat_param._sharded_size.numel() - 1 @@ -1666,11 +1653,11 @@ def _allgather_orig_param_states( def _gather_all_orig_param_state( fsdp_param_info: FSDPParamInfo, - input_states: Dict[str, Any], + input_states: dict[str, Any], shard_state: bool, to_save: bool, cpu_offload: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Given a optimizer state dict, ``input_states``, which the keys are FQNs to the original parameters (not FlatParameters nor parmeter ID), gather all the @@ -1715,20 +1702,20 @@ def _gather_all_orig_param_state( def _convert_state_with_orig_params( - all_optim_state_keys: List[_OptimStateKey], - optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]], - fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo], - optim_state_dict: Dict[Union[str, int], Any], + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], to_save: bool, shard_state: bool, cpu_offload: bool = True, -) -> Dict[str, Any]: - fsdp_osd_state: Dict[str, Any] = {} +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo # usually corresponds to multiple parameters. We could not use FSDPParamInfo # as the key because FSDPParamInfo is not hashable. As a result, we fall back # to `id(FSDPParamInfo)`, which the type is an integer. - all_states: Dict[int, Dict[str, Any]] = {} + all_states: dict[int, dict[str, Any]] = {} # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers # across ranks for optim_state_key in all_optim_state_keys: @@ -1806,15 +1793,15 @@ def _convert_state_with_orig_params( def _convert_state_with_flat_params( - all_optim_state_keys: List[_OptimStateKey], - optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]], - fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo], - optim_state_dict: Dict[Union[str, int], Any], + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], to_save: bool, shard_state: bool, cpu_offload: bool = True, -) -> Dict[str, Any]: - fsdp_osd_state: Dict[str, Any] = {} +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers # across ranks for optim_state_key in all_optim_state_keys: @@ -1866,10 +1853,10 @@ def _convert_state_with_flat_params( def _optim_state_dict( model: nn.Module, optim: torch.optim.Optimizer, - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[nn.Parameter], ] ], @@ -1879,7 +1866,7 @@ def _optim_state_dict( using_optim_input: bool, use_orig_params: bool = False, cpu_offload: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Consolidates the optimizer state and returns it as a :class:`dict` following the convention of :meth:`torch.optim.Optimizer.state_dict`, @@ -1940,7 +1927,7 @@ def _optim_state_dict( is_named_optimizer = _is_named_optimizer(optim_state_dict) 
param_key_to_param = cast( - Dict[Union[int, str], nn.Parameter], + dict[Union[int, str], nn.Parameter], ( _get_param_id_to_param_from_optim_input(model, optim_input) if using_optim_input @@ -1985,7 +1972,7 @@ def _optim_state_dict( if not to_save: return {} - fsdp_osd: Dict[str, Any] = {"state": fsdp_osd_state} + fsdp_osd: dict[str, Any] = {"state": fsdp_osd_state} flat_param_fqns = set(flat_param_to_fqn.values()) for key, value in optim_state_dict["state"].items(): @@ -2019,7 +2006,7 @@ def _optim_state_dict( return fsdp_osd -def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]: +def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]: """ Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo`` if the param is managed by FSDP. Shared parameters, or original parameters that @@ -2055,7 +2042,7 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]: def return_fn(fqn_to_param_info): return fqn_to_param_info - fqn_to_param_info: Dict[str, FSDPParamInfo] = {} + fqn_to_param_info: dict[str, FSDPParamInfo] = {} # FlatParameter._fqns stores the local fqn, starting from the root of the # FSDP. Using _apply_to_modules() with model (may not be the FSDP root # module) allows us to construct the global fqn. diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 2f3c9bb0c00..da8ae59724b 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -2,7 +2,7 @@ import functools import logging from enum import auto, Enum -from typing import Any, Callable, Dict, List, no_type_check, Optional, Set +from typing import Any, Callable, no_type_check, Optional import torch import torch.distributed as dist @@ -57,7 +57,7 @@ class _PrefetchMode(Enum): def _get_fsdp_root_states_with_modules( module: nn.Module, -) -> tuple[List[_FSDPState], List[nn.Module]]: +) -> tuple[list[_FSDPState], list[nn.Module]]: """ Returns a tuple containing: 1. A list of the root ``_FSDPState`` instances in the module tree rooted at @@ -70,9 +70,9 @@ def _get_fsdp_root_states_with_modules( must call :func:`_is_fsdp_root` to force a lazy initialization to determine the FSDP root in case lazy initialization has not yet happened. """ - fsdp_root_states: List[_FSDPState] = [] - fsdp_root_modules: List[nn.Module] = [] - visited_fsdp_states: Set[_FSDPState] = set() + fsdp_root_states: list[_FSDPState] = [] + fsdp_root_modules: list[nn.Module] = [] + visited_fsdp_states: set[_FSDPState] = set() # NOTE: This function assumes that `module.modules()` proceeds top-down. 
for submodule in module.modules(): optional_state = _get_module_fsdp_state(submodule) @@ -87,7 +87,7 @@ def _get_fsdp_root_states_with_modules( return fsdp_root_states, fsdp_root_modules -def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]: +def _get_fsdp_root_states(module: nn.Module) -> list[_FSDPState]: """See :func:`_get_fsdp_root_states_with_modules`.""" fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module) return fsdp_root_states @@ -178,7 +178,7 @@ def _share_state_and_init_handle_attrs( handle = root_state._handle if handle: handle.init_flat_param_attributes() - attr_name_to_values: Dict[str, Set[Any]] = {} + attr_name_to_values: dict[str, set[Any]] = {} for attr_name in HOMOGENEOUS_ATTR_NAMES: attr_name_to_values[attr_name] = set() root_state._all_handles = root_state._exec_order_data.all_handles # share reference @@ -347,8 +347,8 @@ def _pre_forward( unshard_fn: Callable, module: nn.Module, args: tuple[Any, ...], - kwargs: Dict[str, Any], -) -> tuple[tuple[Any, ...], Dict[str, Any]]: + kwargs: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: """ Runs the pre-forward logic. This includes an opportunity to unshard currently sharded parameters such as those for the current forward and @@ -1469,7 +1469,7 @@ def _register_post_backward_reshard_only_hook( state: _FSDPState, handle: Optional[FlatParamHandle], args: tuple[Any, ...], - kwargs: Dict[str, Any], + kwargs: dict[str, Any], ) -> None: """ Registers post-backward hooks to reshard flat parameters that do not @@ -1483,7 +1483,7 @@ def _register_post_backward_reshard_only_hook( return # Construct `inp_tensors` lazily to avoid CPU overhead in typical case # where each flat parameter requires gradient - inp_tensors: Optional[List[torch.Tensor]] = None + inp_tensors: Optional[list[torch.Tensor]] = None if not handle: return flat_param = handle.flat_param @@ -1555,7 +1555,7 @@ def _wait_for_computation_stream( def _reset_flat_param_grad_info_if_needed( - handles: List[FlatParamHandle], + handles: list[FlatParamHandle], ): """ Clears the original parameters' gradients if needed. This method's CPU @@ -1573,7 +1573,7 @@ def _reset_flat_param_grad_info_if_needed( def _get_buffers_and_dtypes_for_computation( state: _FSDPState, root_module: nn.Module, -) -> tuple[List[torch.Tensor], List[Optional[torch.dtype]]]: +) -> tuple[list[torch.Tensor], list[Optional[torch.dtype]]]: """ Returns all buffers in the module tree rooted at ``root_module`` and a corresponding list of the buffer dtypes for computation. Each buffer dtype @@ -1581,9 +1581,9 @@ def _get_buffers_and_dtypes_for_computation( low precision dtype otherwise. """ _p_assert(state._is_root, "Expects the root to cast buffers") - buffers: List[torch.Tensor] = [] - buffer_dtypes: List[Optional[torch.dtype]] = [] - visited_buffers: Set[torch.Tensor] = set() + buffers: list[torch.Tensor] = [] + buffer_dtypes: list[Optional[torch.dtype]] = [] + visited_buffers: set[torch.Tensor] = set() # Traverse the FSDP states bottom-up so that we prefer the owning FSDP # instance's mixed precision setting for each buffer fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules( @@ -1605,12 +1605,12 @@ def _get_buffers_and_dtypes_for_computation( @no_type_check def _get_orig_buffer_dtypes( state: _FSDPState, - buffer_names: List[str], -) -> List[torch.dtype]: + buffer_names: list[str], +) -> list[torch.dtype]: """ Returns the original buffer types of the given buffer names. 
""" - buffer_dtypes: List[torch.dtype] = [] + buffer_dtypes: list[torch.dtype] = [] for buffer_name in buffer_names: _p_assert( buffer_name in state._buffer_name_to_orig_dtype, @@ -1623,8 +1623,8 @@ def _get_orig_buffer_dtypes( def _cast_buffers_to_dtype_and_device( - buffers: List[torch.Tensor], - buffer_dtypes: List[Optional[torch.dtype]], + buffers: list[torch.Tensor], + buffer_dtypes: list[Optional[torch.dtype]], device: torch.device, ) -> None: """ diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 5531b8328e3..72b4b60b3f9 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -3,7 +3,8 @@ import contextlib import logging import math import warnings -from typing import Any, Callable, cast, Dict, Generator, Iterator, List, no_type_check +from collections.abc import Generator, Iterator +from typing import Any, Callable, cast, no_type_check import torch import torch.distributed as dist @@ -170,10 +171,10 @@ def _common_unshard_pre_state_dict_hook( def _common_unshard_post_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, param_hook: Callable, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ The post-state_dict flow that shared by all state_dict types that require ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this @@ -302,9 +303,9 @@ def _full_pre_state_dict_hook( def _full_post_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Hook that runs after model.state_dict() is called before returning result to user. For FSDP, we may have to clone the tensors in state_dict as params go @@ -313,7 +314,7 @@ def _full_post_state_dict_hook( """ def param_hook( - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, fqn: str, ) -> None: @@ -347,7 +348,7 @@ def _full_post_state_dict_hook( def _full_pre_load_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, ) -> None: _lazy_init(fsdp_state, module) @@ -393,9 +394,9 @@ def _local_pre_state_dict_hook( def _local_post_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ This hook create a ShardedTensor from the local flat_param and replace the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy @@ -448,7 +449,7 @@ def _local_post_load_state_dict_hook( def _local_pre_load_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, ) -> None: """ @@ -526,15 +527,15 @@ def _sharded_pre_state_dict_hook( def _sharded_post_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ The hook replaces the unflattened, unsharded parameter in the state_dict with a unflattened, sharded parameter (a ShardedTensor). 
""" - def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str): + def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str): param = state_dict[fqn] if not fsdp_state._state_dict_config._use_dtensor: sharded_tensor = _ext_chunk_tensor( @@ -574,7 +575,7 @@ def _sharded_post_load_state_dict_hook( def _sharded_pre_load_state_dict_hook( module: nn.Module, fsdp_state: _FSDPState, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, ) -> None: """ @@ -686,10 +687,10 @@ def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator: @torch.no_grad() def _post_state_dict_hook( module: nn.Module, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, *args: Any, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide @@ -802,7 +803,7 @@ def _set_use_dtensor(fsdp_state: _FSDPState) -> None: @torch.no_grad() def _pre_load_state_dict_hook( module: nn.Module, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, *args: Any, ) -> None: @@ -845,7 +846,7 @@ def _pre_load_state_dict_hook( @torch.no_grad() def _post_load_state_dict_hook( module: nn.Module, - incompatible_keys: tuple[List[str], List[str]], + incompatible_keys: tuple[list[str], list[str]], *args: Any, ) -> None: fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) @@ -906,7 +907,7 @@ def _register_state_dict_hooks_base( state: _FSDPState, hook_registration_fn_name: str, hook: Callable, - hook_registration_fn_kwargs: Dict[str, Any], + hook_registration_fn_kwargs: dict[str, Any], ) -> None: """Registers ``hook`` using ``hook_registration_fn``.""" if not _is_composable(state): diff --git a/torch/distributed/fsdp/_trace_utils.py b/torch/distributed/fsdp/_trace_utils.py index 0cc69a3c4e6..fcd09f6ce9f 100644 --- a/torch/distributed/fsdp/_trace_utils.py +++ b/torch/distributed/fsdp/_trace_utils.py @@ -2,7 +2,7 @@ import functools from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set +from typing import Any, Callable, NamedTuple, Optional import torch import torch.nn as nn @@ -29,7 +29,7 @@ class TracingConfig: """ tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer) - concrete_args: Optional[Dict[str, Any]] = None + concrete_args: Optional[dict[str, Any]] = None class _ParamUsageInfo(NamedTuple): @@ -52,7 +52,7 @@ class _ParamUsageInfo(NamedTuple): """ module: nn.Module - named_params: List[tuple[str, nn.Parameter]] + named_params: list[tuple[str, nn.Parameter]] class _ExecutionInfo: @@ -79,12 +79,12 @@ class _ExecutionInfo: def __init__(self, root_module: nn.Module) -> None: self.curr_module: nn.Module = root_module - self.module_forward_order: List[nn.Module] = [root_module] - self.module_to_param_usage_infos: Dict[nn.Module, List[_ParamUsageInfo]] = { + self.module_forward_order: list[nn.Module] = [root_module] + self.module_to_param_usage_infos: dict[nn.Module, list[_ParamUsageInfo]] = { root_module: [] } - self.param_forward_order: List[nn.Parameter] = [] - self.visited_params: Set[nn.Parameter] = set() + self.param_forward_order: list[nn.Parameter] = [] + self.visited_params: set[nn.Parameter] = set() class _ExecOrderTracer: @@ -120,7 +120,7 @@ class _ExecOrderTracer: module: nn.Module, forward: Callable, args: tuple[Any, ...], - kwargs: Dict[str, Any], + kwargs: dict[str, Any], ) -> Any: """ Overrides 
``call_module`` to save execution information to @@ -160,12 +160,12 @@ class _ExecOrderTracer: self, create_proxy: Callable, exec_info: _ExecutionInfo, - fqn_to_param: Dict[str, nn.Parameter], + fqn_to_param: dict[str, nn.Parameter], # Below are the expected arguments to `create_proxy()` kind: str, target: torch.fx.node.Target, args: tuple[Any, ...], - kwargs: Dict[str, Any], + kwargs: dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None, @@ -210,7 +210,7 @@ class _ExecOrderTracer: curr_module = exec_info.curr_module if kind in ("call_function", "call_method"): if args is not None: - named_params: List[tuple[str, nn.Parameter]] = [] + named_params: list[tuple[str, nn.Parameter]] = [] for arg in args: if ( isinstance(arg, torch.fx.Proxy) diff --git a/torch/distributed/fsdp/_traversal_utils.py b/torch/distributed/fsdp/_traversal_utils.py index 1873f1af238..5ca758c83a9 100644 --- a/torch/distributed/fsdp/_traversal_utils.py +++ b/torch/distributed/fsdp/_traversal_utils.py @@ -6,7 +6,6 @@ imports. For brevity, we may import the file as ``traversal_utils``. """ import collections -from typing import Deque, List, Set import torch.nn as nn from torch.distributed._composable.contract import _get_registry @@ -48,7 +47,7 @@ def _composable(module: nn.Module) -> bool: # `FlatParameter` registration, which is not needed for `use_orig_params=True`. def _get_fsdp_states_with_modules( module: nn.Module, -) -> tuple[List[_FSDPState], List[nn.Module]]: +) -> tuple[list[_FSDPState], list[nn.Module]]: """ Returns a tuple containing: 1. A list of the ``_FSDPState`` instances in the module tree rooted at @@ -65,19 +64,19 @@ def _get_fsdp_states_with_modules( NOTE: The traversal does not proceed into any module annotated by an incompatible API (e.g. ``replicate``). """ - fsdp_states: List[_FSDPState] = [] - fsdp_modules: List[nn.Module] = [] + fsdp_states: list[_FSDPState] = [] + fsdp_modules: list[nn.Module] = [] # Track the visited FSDP states since multiple modules may share the same # one and we want to return a de-duplicated list - visited_fsdp_states: Set[_FSDPState] = set() + visited_fsdp_states: set[_FSDPState] = set() # Track the visited modules in case of shared modules, which implies the # module graph is no longer a tree - visited_modules: Set[nn.Module] = set() + visited_modules: set[nn.Module] = set() # Perform depth-first search from `module` to ensure that we do not # traverse into an incompatible API's subtree (use DFS instead of BFS to # match `.modules()` order) - deque: Deque[nn.Module] = collections.deque([module]) + deque: collections.deque[nn.Module] = collections.deque([module]) while deque: submodule = deque.popleft() visited_modules.add(submodule) @@ -94,13 +93,13 @@ def _get_fsdp_states_with_modules( return fsdp_states, fsdp_modules -def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]: +def _get_fsdp_states(module: nn.Module) -> list[_FSDPState]: """See :func:`_get_fsdp_states_with_modules`.""" fsdp_states, _ = _get_fsdp_states_with_modules(module) return fsdp_states -def _get_fsdp_handles(module: nn.Module) -> List: +def _get_fsdp_handles(module: nn.Module) -> list: """ Returns all ``FlatParamHandle`` s in the module tree rooted at ``module`` following the rules in :func:`_get_fsdp_state`. 
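Some of the rewrites above land in runtime expressions rather than annotations, e.g. `cast(List[nn.Parameter], list(optim_input))` in `_optim_utils.py` becoming `cast(list[nn.Parameter], ...)`; a subscripted builtin there is evaluated when the line runs, so it relies on the PEP 585 support added in Python 3.9 (a `from __future__ import annotations` line would not defer it). The sketch below uses hypothetical names and mirrors the deque-based module traversal in `_traversal_utils.py` together with such a runtime `cast`:

# Illustrative only; names are hypothetical and not taken from the FSDP sources.
import collections
from typing import Any, cast

import torch.nn as nn


def find_linear_modules(root: nn.Module) -> list[nn.Module]:
    """Collect nn.Linear submodules via BFS over the module tree."""
    found: list[nn.Module] = []
    visited: set[nn.Module] = set()
    queue: collections.deque[nn.Module] = collections.deque([root])
    while queue:
        module = queue.popleft()
        if module in visited:
            continue
        visited.add(module)
        if isinstance(module, nn.Linear):
            found.append(module)
        queue.extend(module.children())
    return found


def as_param_list(maybe_params: Any) -> list[nn.Parameter]:
    # `list[nn.Parameter]` is evaluated at runtime before being handed to
    # `cast`, mirroring the rewritten call sites in the patch above.
    return cast(list[nn.Parameter], list(maybe_params))

On Python 3.8 the `list[nn.Parameter]` expression would raise `TypeError: 'type' object is not subscriptable`, which is why this class of call-site rewrite presumes a 3.9+ floor.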
diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py index 4143d2928c8..ad495c73426 100644 --- a/torch/distributed/fsdp/_unshard_param_utils.py +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import contextlib import warnings -from typing import cast, Generator +from collections.abc import Generator +from typing import cast import torch import torch.distributed.fsdp._traversal_utils as traversal_utils diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index f480006b288..ceecabcacf7 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -4,7 +4,7 @@ import functools import inspect import warnings from functools import partial -from typing import Any, Callable, Dict, List, Set, Type, Union +from typing import Any, Callable, Union import torch.nn as nn from torch.distributed.fsdp._common_utils import ( @@ -25,9 +25,9 @@ from torch.distributed.fsdp.wrap import ( def _auto_wrap( root_module: nn.Module, policy: Union[Callable, _Policy], - ignored_modules: Set[nn.Module], - ignored_params: Set[nn.Parameter], - root_kwargs: Dict[str, Any], + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + root_kwargs: dict[str, Any], fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard` ): """ @@ -111,7 +111,7 @@ def _check_nested_wrapping(root_module: nn.Module): def _warn_on_overridden_mixed_precision( - overridden_module_classes: Set[Type[nn.Module]], + overridden_module_classes: set[type[nn.Module]], ): if len(overridden_module_classes) == 0: return @@ -125,8 +125,8 @@ def _warn_on_overridden_mixed_precision( def _validate_frozen_params( root_module: nn.Module, - modules_to_wrap: Set[nn.Module], - ignored_params: Set[nn.Parameter], + modules_to_wrap: set[nn.Module], + ignored_params: set[nn.Parameter], use_orig_params: bool, ): """ @@ -136,15 +136,15 @@ def _validate_frozen_params( recommended for ``use_orig_params=True`` (user warning). """ post_order_named_modules = _get_post_order_named_modules(root_module) - visited_modules: Set[nn.Module] = set() + visited_modules: set[nn.Module] = set() for module_name, module in post_order_named_modules: if module in modules_to_wrap: param_to_fqn = _get_managed_param_to_fqn( module, ignored_params, visited_modules, module_name ) - frozen_param_fqns: List[str] = [] + frozen_param_fqns: list[str] = [] frozen_param_numel = 0 - nonfrozen_param_fqns: List[str] = [] + nonfrozen_param_fqns: list[str] = [] nonfrozen_param_numel = 0 for param, fqn in param_to_fqn.items(): if param.requires_grad: @@ -178,7 +178,7 @@ def _validate_frozen_params( def _get_post_order_named_modules( root_module: nn.Module, -) -> List[tuple[str, nn.Module]]: +) -> list[tuple[str, nn.Module]]: """ This returns the named modules following a post-order traversal, which is a valid reverse topological sort. 
We achieve this using the reverse of a @@ -202,7 +202,7 @@ def _get_post_order_named_modules( visited_modules = {root_module} stack = [("", root_module)] # Append and reverse at the end for linear-time algorithm - reverse_post_order_named_modules: List[tuple[str, nn.Module]] = [] + reverse_post_order_named_modules: list[tuple[str, nn.Module]] = [] while stack: module_name, module = stack.pop() reverse_post_order_named_modules.append((module_name, module)) @@ -220,10 +220,10 @@ def _get_post_order_named_modules( def _get_managed_param_to_fqn( module_to_wrap: nn.Module, - ignored_params: Set[nn.Parameter], - visited_modules: Set[nn.Module], + ignored_params: set[nn.Parameter], + visited_modules: set[nn.Module], root_prefix: str, -) -> Dict[nn.Parameter, str]: +) -> dict[nn.Parameter, str]: """ This returns a dict that maps managed parameter to its FQN for the given ``module_to_wrap``. The dict's keys are exactly the parameters that would @@ -238,7 +238,7 @@ def _get_managed_param_to_fqn( on the full module tree in one shot. Given those differences, we do not try to unify the two. """ - param_to_fqn: Dict[nn.Parameter, str] = {} + param_to_fqn: dict[nn.Parameter, str] = {} # Run BFS (or any tree traversal works) queue = collections.deque([(module_to_wrap, root_prefix)]) visited_modules.add(module_to_wrap) diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index c1d24413fdb..7282fbcd7b5 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -3,9 +3,10 @@ This file includes public APIs for FSDP such as the classes used for the constructor arguments. """ +from collections.abc import Sequence from dataclasses import dataclass from enum import auto, Enum -from typing import Optional, Sequence, Type +from typing import Optional import torch from torch.nn.modules.batchnorm import _BatchNorm @@ -223,7 +224,7 @@ class MixedPrecision: keep_low_precision_grads: bool = False cast_forward_inputs: bool = False cast_root_forward_inputs: bool = True - _module_classes_to_ignore: Sequence[Type[torch.nn.Module]] = (_BatchNorm,) + _module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,) @dataclass diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 735bc04bc52..d1de71e5cc9 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -6,19 +6,10 @@ import functools import math import traceback import warnings +from collections.abc import Generator, Iterable, Iterator from contextlib import contextmanager from enum import auto, Enum -from typing import ( - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - List, - Optional, - Union, -) +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -562,7 +553,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): def fsdp_modules( module: nn.Module, root_only: bool = False, - ) -> List["FullyShardedDataParallel"]: + ) -> list["FullyShardedDataParallel"]: """Return all nested FSDP instances. This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``. 
@@ -1013,7 +1004,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): param_name = param_name.replace(FSDP_PREFIX, "") yield (param_name, param) - def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None: + def _assert_state(self, state: Union[TrainingState, list[TrainingState]]) -> None: """Assert we are in the given state.""" # Since assert can be turned off and this error checking # is really important, we use explicit error checking @@ -1135,7 +1126,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): # iteration order and hence deterministic total norm computation sharded_params = [] nonsharded_params = [] - grads: List[torch.Tensor] = [] + grads: list[torch.Tensor] = [] for handle in self._all_handles: if handle.uses_sharded_strategy: target_set = sharded_params_set @@ -1257,10 +1248,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): def _optim_state_dict_impl( model: torch.nn.Module, optim: torch.optim.Optimizer, - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[torch.nn.Parameter], ] ] = None, @@ -1270,7 +1261,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): cpu_offload: bool = True, *, _stacklevel: int = 1, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Transform the state-dict of an optimizer corresponding to a sharded model. This is the internal API that is used by all the optim_state_dict implementations. @@ -1312,11 +1303,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): @staticmethod def _optim_state_dict_to_load_impl( - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], model: torch.nn.Module, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[torch.nn.Parameter], ] ] = None, @@ -1325,7 +1316,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): rank0_only: bool = False, is_named_optimizer: bool = False, group: Optional[dist.ProcessGroup] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. @@ -1376,13 +1367,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): optim: torch.optim.Optimizer, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[torch.nn.Parameter], ] ] = None, rank0_only: bool = True, group: Optional[dist.ProcessGroup] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Return the full optimizer state-dict. Consolidates the full optimizer state on rank 0 and returns it @@ -1451,7 +1442,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): model: torch.nn.Module, optim: torch.optim.Optimizer, group: Optional[dist.ProcessGroup] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Return the optimizer state-dict in its sharded form. The API is similar to :meth:`full_optim_state_dict` but this API chunks @@ -1482,16 +1473,16 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): @staticmethod def shard_full_optim_state_dict( - full_optim_state_dict: Dict[str, Any], + full_optim_state_dict: dict[str, Any], model: torch.nn.Module, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[torch.nn.Parameter], ] ] = None, optim: Optional[torch.optim.Optimizer] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Shard a full optimizer state-dict. 
Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened @@ -1561,10 +1552,10 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): @staticmethod def flatten_sharded_optim_state_dict( - sharded_optim_state_dict: Dict[str, Any], + sharded_optim_state_dict: dict[str, Any], model: torch.nn.Module, optim: torch.optim.Optimizer, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Flatten a sharded optimizer state-dict. The API is similar to :meth:`shard_full_optim_state_dict`. The only @@ -1600,17 +1591,17 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): @staticmethod def scatter_full_optim_state_dict( - full_optim_state_dict: Optional[Dict[str, Any]], + full_optim_state_dict: Optional[dict[str, Any]], model: torch.nn.Module, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[torch.nn.Parameter], ] ] = None, optim: Optional[torch.optim.Optimizer] = None, group: Optional[Any] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Scatter the full optimizer state dict from rank 0 to all other ranks. Returns the sharded optimizer state dict on each rank. @@ -1684,17 +1675,17 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): @staticmethod def rekey_optim_state_dict( - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], optim_state_key_type: OptimStateKeyType, model: torch.nn.Module, optim_input: Optional[ Union[ - List[Dict[str, Any]], + list[dict[str, Any]], Iterable[torch.nn.Parameter], ] ] = None, optim: Optional[torch.optim.Optimizer] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``. This can be used to achieve compatibility between optimizer state dicts from models with FSDP @@ -1762,7 +1753,7 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): else _get_param_key_to_param(optim) ) param_to_param_name = _get_param_to_fqn(model) - param_id_to_param_name: List[str] = [ + param_id_to_param_name: list[str] = [ param_to_param_name[param] for param in param_id_to_param.values() ] new_osd["state"] = { @@ -1811,9 +1802,9 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): def optim_state_dict( model: torch.nn.Module, optim: torch.optim.Optimizer, - optim_state_dict: Optional[Dict[str, Any]] = None, + optim_state_dict: Optional[dict[str, Any]] = None, group: Optional[dist.ProcessGroup] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Transform the state-dict of an optimizer corresponding to a sharded model. @@ -1907,11 +1898,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState): def optim_state_dict_to_load( model: torch.nn.Module, optim: torch.optim.Optimizer, - optim_state_dict: Dict[str, Any], + optim_state_dict: dict[str, Any], is_named_optimizer: bool = False, load_directly: bool = False, group: Optional[dist.ProcessGroup] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. @@ -2146,7 +2137,7 @@ def _get_grad_norm( def _get_param_to_fqn( model: torch.nn.Module, -) -> Dict[torch.nn.Parameter, str]: +) -> dict[torch.nn.Parameter, str]: """ Construct a mapping from parameters to their parameter names. 
@@ -2178,7 +2169,7 @@ def _get_param_to_fqn( def _get_fqn_to_param( model: torch.nn.Module, -) -> Dict[str, torch.nn.Parameter]: +) -> dict[str, torch.nn.Parameter]: """Construct the inverse mapping of :meth:`_get_param_to_fqn`.""" param_to_param_name = _get_param_to_fqn(model) return dict(zip(param_to_param_name.values(), param_to_param_name.keys())) diff --git a/torch/distributed/fsdp/sharded_grad_scaler.py b/torch/distributed/fsdp/sharded_grad_scaler.py index c51c428238e..d19cb720543 100644 --- a/torch/distributed/fsdp/sharded_grad_scaler.py +++ b/torch/distributed/fsdp/sharded_grad_scaler.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import logging from collections import abc, defaultdict -from typing import Any, Dict, Iterable, List, Optional, overload, Union +from collections.abc import Iterable +from typing import Any, Optional, overload, Union import torch import torch.distributed as dist @@ -12,7 +13,7 @@ from torch.distributed.distributed_c10d import ProcessGroup logger = logging.getLogger(__name__) -def _refresh_per_optimizer_state() -> Dict[str, Any]: +def _refresh_per_optimizer_state() -> dict[str, Any]: return {"stage": OptState.READY, "found_inf_per_device": {}} @@ -36,7 +37,7 @@ class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator): def __init__(self, master_tensor: torch.Tensor) -> None: assert _is_supported_device(master_tensor) self.master = master_tensor - self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + self._per_device_tensors: dict[torch.device, torch.Tensor] = {} class ShardedGradScaler(GradScaler): @@ -115,7 +116,7 @@ class ShardedGradScaler(GradScaler): ... @overload - def scale(self, outputs: List[torch.Tensor]) -> List[torch.Tensor]: + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload @@ -145,7 +146,7 @@ class ShardedGradScaler(GradScaler): # format (fp16, bf16) and so the scaled loss should be of the same dtype. return scaled_output.type(outputs.dtype) - stash: List[_GeneralMultiDeviceReplicator] = [] + stash: list[_GeneralMultiDeviceReplicator] = [] def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): if isinstance(val, torch.Tensor): @@ -175,7 +176,7 @@ class ShardedGradScaler(GradScaler): inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool = True, - ) -> Dict[torch.device, torch.Tensor]: + ) -> dict[torch.device, torch.Tensor]: per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale) per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf) diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index bbb909e003f..55d6b3bc58f 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -7,19 +7,8 @@ import contextlib import copy from abc import ABC, abstractmethod -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - Iterable, - Optional, - Sequence, - Set, - Type, - Union, -) +from collections.abc import Generator, Iterable, Sequence +from typing import Any, Callable, cast, Optional, Union import torch.nn as nn @@ -51,7 +40,7 @@ def _post_order_apply( not changed. 
""" # Track visited modules to avoid visiting shared modules multiple times - visited_modules: Set[nn.Module] = {root_module} + visited_modules: set[nn.Module] = {root_module} def _post_order_apply_inner( module: nn.Module, @@ -82,7 +71,7 @@ def _post_order_apply( def _construct_wrap_fn( root_module: nn.Module, - target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]], + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], fsdp_fn: Callable, ) -> Callable[[nn.Module], Optional[nn.Module]]: """ @@ -104,10 +93,10 @@ def _construct_wrap_fn( def _run_mixed_precision_override_policy( root_module: nn.Module, - module_classes: Iterable[Type[nn.Module]], - ignored_modules: Set[nn.Module], - root_kwargs: Dict[str, Any], - target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]], + module_classes: Iterable[type[nn.Module]], + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], ): module_classes_tuple = tuple(set(module_classes)) for module in root_module.modules(): @@ -141,9 +130,9 @@ class _Policy(ABC): def _run_policy( self, root_module: nn.Module, - ignored_modules: Set[nn.Module], - root_kwargs: Dict[str, Any], - ) -> Dict[nn.Module, Dict[str, Any]]: + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: """ This should return a dict ``target_module_to_kwargs`` that maps from each target module to wrap to its kwargs. @@ -155,7 +144,7 @@ def _module_wrap_policy( module: nn.Module, recurse: bool, nonwrapped_numel: int, - module_classes: Set[Type[nn.Module]], + module_classes: set[type[nn.Module]], ) -> bool: """ This auto wrap policy wraps every module that is an instance of any type in @@ -189,7 +178,7 @@ class ModuleWrapPolicy(_Policy): passing in the kwargs given to the root. 
""" - def __init__(self, module_classes: Iterable[Type[nn.Module]]): + def __init__(self, module_classes: Iterable[type[nn.Module]]): module_classes_set = set(module_classes) self._module_classes = module_classes_set self._module_classes_str = str(module_classes_set) @@ -197,11 +186,11 @@ class ModuleWrapPolicy(_Policy): def _run_policy( self, root_module: nn.Module, - ignored_modules: Set[nn.Module], - root_kwargs: Dict[str, Any], - ) -> Dict[nn.Module, Dict[str, Any]]: + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: module_classes = tuple(self._module_classes) - target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {} + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} for module in root_module.modules(): if module in ignored_modules: continue @@ -245,16 +234,16 @@ class CustomPolicy(_Policy): >>> fsdp_model = FSDP(model, auto_wrap_policy=policy) """ - def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]): + def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]): self._lambda_fn = lambda_fn def _run_policy( self, root_module: nn.Module, - ignored_modules: Set[nn.Module], - root_kwargs: Dict[str, Any], - ) -> Dict[nn.Module, Dict[str, Any]]: - target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {} + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} for module in root_module.modules(): if module in ignored_modules: continue @@ -307,7 +296,7 @@ def transformer_auto_wrap_policy( module: nn.Module, recurse: bool, nonwrapped_numel: int, - transformer_layer_cls: Set[Type[nn.Module]], + transformer_layer_cls: set[type[nn.Module]], ) -> bool: """ See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the @@ -352,8 +341,8 @@ def size_based_auto_wrap_policy( nonwrapped_numel: int, # Additional custom arguments min_num_params: int = int(1e8), - force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, - exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, + force_leaf_modules: Optional[set[type[nn.Module]]] = None, + exclude_wrap_modules: Optional[set[type[nn.Module]]] = None, ) -> bool: """ A size-based auto wrap policy. @@ -495,8 +484,8 @@ def _recursive_wrap( module: nn.Module, auto_wrap_policy: Callable, wrapper_cls: Callable, - ignored_modules: Set[nn.Module], - ignored_params: Set[nn.Parameter], + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], only_wrap_children: bool = False, **kwargs: Any, ) -> tuple[nn.Module, int]: @@ -573,9 +562,9 @@ class _ConfigAutoWrap: in_autowrap_context: bool = False # Context flag wrapper_cls: Optional[Callable] = None # The wrapper class - kwargs: Dict[str, Any] = {} # Wrapper's args + kwargs: dict[str, Any] = {} # Wrapper's args - def __init__(self, **kwargs: Dict[str, Any]): + def __init__(self, **kwargs: dict[str, Any]): self.kwargs = kwargs @staticmethod