diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 181fab713fe..a4dd3459783 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -79,7 +79,7 @@ if is_available(): finally: sys.stdin = _stdin - _breakpoint_cache: typing.Dict[int, typing.Any] = {} + _breakpoint_cache: dict[int, typing.Any] = {} def breakpoint(rank: int = 0, skip: int = 0): """ diff --git a/torch/distributed/_checkpointable.py b/torch/distributed/_checkpointable.py index 359790c3950..bc0a288f129 100644 --- a/torch/distributed/_checkpointable.py +++ b/torch/distributed/_checkpointable.py @@ -1,5 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -from typing import List, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable import torch @@ -12,7 +12,7 @@ class _Checkpointable(Protocol): # noqa: PYI046 This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface. """ - def __create_write_items__(self, fqn: str, object: object) -> List[object]: + def __create_write_items__(self, fqn: str, object: object) -> list[object]: """ Return a list of WriteItems based on object's contents. """ @@ -20,7 +20,7 @@ class _Checkpointable(Protocol): # noqa: PYI046 "_Checkpointable._create_write_items is not implemented" ) - def __create_chunk_list__(self) -> List[object]: + def __create_chunk_list__(self) -> list[object]: """ Return a list of `ChunkStorageMetadata` based on object's contents. """ diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py index a1d15d1128c..df801543af5 100644 --- a/torch/distributed/_composable/checkpoint_activation.py +++ b/torch/distributed/_composable/checkpoint_activation.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs +from collections.abc import Generator from contextlib import contextmanager, nullcontext -from typing import Any, ContextManager, Dict, Generator, Optional +from typing import Any, ContextManager, Optional import torch import torch.nn as nn @@ -85,7 +86,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module: ) def forward_pre_hook( - module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any] + module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> None: if checkpoint.state(module).enable_hook: diff --git a/torch/distributed/_composable/contract.py b/torch/distributed/_composable/contract.py index 5954d317f1a..cb594c7be1b 100644 --- a/torch/distributed/_composable/contract.py +++ b/torch/distributed/_composable/contract.py @@ -1,20 +1,9 @@ # mypy: allow-untyped-defs import uuid from collections import OrderedDict +from collections.abc import Sequence from functools import wraps -from typing import ( - Callable, - Dict, - Generic, - List, - Optional, - overload, - Protocol, - Sequence, - Type, - TypeVar, - Union, -) +from typing import Callable, Generic, Optional, overload, Protocol, TypeVar, Union from typing_extensions import Concatenate, ParamSpec import torch @@ -66,7 +55,7 @@ def contract() -> ( @overload def contract( - state_cls: Type[_TState], + state_cls: type[_TState], ) -> Callable[ [Callable[Concatenate[nn.Module, _P], Optional[nn.Module]]], _ContractFn[Concatenate[nn.Module, _P], _T, _TState], @@ -75,7 +64,7 @@ def contract( def contract( - state_cls: Type = _State, + state_cls: type = _State, ) -> Callable[ [ Callable[ @@ -153,21 +142,21 @@ def contract( # `func` is allowed to return different module instances than the # input modules as long as FQNs are preserved following the input # module order - all_orig_named_params: List[Dict[str, nn.Parameter]] = [] - all_orig_named_buffers: List[Dict[str, torch.Tensor]] = [] - all_orig_named_modules: List[Dict[str, nn.Module]] = [] + all_orig_named_params: list[dict[str, nn.Parameter]] = [] + all_orig_named_buffers: list[dict[str, torch.Tensor]] = [] + all_orig_named_modules: list[dict[str, nn.Module]] = [] for module in modules: - default_all_state: Dict[Callable, _State] = OrderedDict() - default_registry: Dict[str, RegistryItem] = OrderedDict() - all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] + default_all_state: dict[Callable, _State] = OrderedDict() + default_registry: dict[str, RegistryItem] = OrderedDict() + all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] STATE_KEY, default_all_state ) if not isinstance(all_state, dict): raise AssertionError( f"Distributed composable API states corrupted: {all_state}" ) - registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] + registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] REGISTRY_KEY, default_registry ) if not isinstance(registry, dict): @@ -195,9 +184,9 @@ def contract( else: updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type] - all_new_named_params: List[Dict[str, nn.Parameter]] = [] - all_new_named_buffers: List[Dict[str, torch.Tensor]] = [] - all_new_named_modules: List[Dict[str, nn.Module]] = [] + all_new_named_params: list[dict[str, nn.Parameter]] = [] + all_new_named_buffers: list[dict[str, torch.Tensor]] = [] + all_new_named_modules: list[dict[str, nn.Module]] = [] for module in updated_modules: all_new_named_params.append(OrderedDict(module.named_parameters())) all_new_named_buffers.append(OrderedDict(module.named_buffers())) @@ -212,7 +201,7 @@ def contract( f"Outputs: {num_new_modules} modules" ) - def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str): + def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str): if orig_fqns == new_fqns: return @@ -280,7 +269,7 @@ def contract( return inner -def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]: +def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]: r""" Get an ``OrderedDict`` of composable APIs that have been applied to the ``module``, indexed by the API name. If no API has been applied, then this diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index d31e11c6445..cb3d916d646 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import weakref -from typing import Any, Dict, Iterable, List, NoReturn, Optional, Set +from collections.abc import Iterable +from typing import Any, NoReturn, Optional import torch import torch.nn as nn @@ -24,17 +25,17 @@ class _ReplicateState(_State): # TODO(@fegin): this variable is originally create for testing, we # should remove this if possible. self._orig_module = self.module - self._param_names: List[str] = [] + self._param_names: list[str] = [] self._no_sync: bool = False self._init_args: Optional[tuple[Any, ...]] = None - self._init_kwargs: Dict[str, Any] = {} - self._comm_hook_args: List[Any] = [] + self._init_kwargs: dict[str, Any] = {} + self._comm_hook_args: list[Any] = [] def _collect_params( self, module: nn.Module, - ignored_modules: Set[nn.Module], - ignored_params: Set[nn.Parameter], + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], prefix: str = _ROOT_MODULE_PREFIX, ) -> None: # skip if managed by fully_sharded API @@ -76,7 +77,7 @@ class _ReplicateState(_State): def init( self, module: nn.Module, - ignored_modules: Set[nn.Module], + ignored_modules: set[nn.Module], **kwargs, ) -> None: if self.has_initialized: @@ -125,7 +126,7 @@ class _ReplicateState(_State): self._init_kwargs = kwargs def forward_pre_hook( - self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any] + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> Any: if self._init_args or self._init_kwargs: self.lazy_init() diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 0347d7c9cf3..67c4adb65b0 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -2,7 +2,7 @@ import contextlib import sys import warnings -from typing import Any, cast, List, Optional, Type, TYPE_CHECKING, Union +from typing import Any, cast, List, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -94,8 +94,8 @@ Functional collectives can accept any of these types to describe the ranks parti The different types will be desugared to a canonical format """ RANK_TYPES = Union[ - List[int], - List[List[int]], + list[int], + list[list[int]], dist.ProcessGroup, DeviceMesh, tuple["dist.tensor.DeviceMesh", int], @@ -331,8 +331,8 @@ def reduce_scatter_tensor_autograd( def all_reduce_coalesced( - self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" -) -> List[torch.Tensor]: + self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" +) -> list[torch.Tensor]: """ Reduces a list of tensors across all machines in such a way that all get the final result. @@ -359,8 +359,8 @@ def all_reduce_coalesced( def all_gather_into_tensor_coalesced( - self: List[torch.Tensor], group: RANK_TYPES, tag: str = "" -) -> List[torch.Tensor]: + self: list[torch.Tensor], group: RANK_TYPES, tag: str = "" +) -> list[torch.Tensor]: """ Gather a list of tensors across from all machines. @@ -388,12 +388,12 @@ def all_gather_into_tensor_coalesced( def reduce_scatter_tensor_coalesced( - inputs: List[torch.Tensor], + inputs: list[torch.Tensor], reduceOp: str, - scatter_dim: List[int], + scatter_dim: list[int], group: RANK_TYPES, tag: str = "", -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: """ Reduces a list of tensors across all machines in such a way that all get the final result, then scatter the results to corresponding ranks. @@ -450,8 +450,8 @@ def _is_view_op(tgt): def all_to_all_single( self: torch.Tensor, - output_split_sizes: Optional[List[int]], - input_split_sizes: Optional[List[int]], + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], group: RANK_TYPES, tag: str = "", ) -> torch.Tensor: @@ -498,8 +498,8 @@ def all_to_all_single( def all_to_all_single_autograd( self: torch.Tensor, - output_split_sizes: Optional[List[int]], - input_split_sizes: Optional[List[int]], + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], group: RANK_TYPES, tag: str = "", ) -> torch.Tensor: @@ -535,7 +535,7 @@ def all_to_all_single_autograd( def permute_tensor( self: torch.Tensor, - src_dst: List[int], + src_dst: list[int], group: RANK_TYPES, tag: str = "", ) -> torch.Tensor: @@ -608,7 +608,7 @@ class AsyncCollectiveTensor(torch.Tensor): return AsyncCollectiveTensor(elem) def __coerce_same_metadata_as_tangent__( - self, expected_metadata: Any, expected_type: Optional[Type] = None + self, expected_metadata: Any, expected_type: Optional[type] = None ): if expected_type is not torch.Tensor: return None @@ -677,7 +677,7 @@ Utils and infrastructure for tracing support """ -def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int]: +def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]: """ _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable. @@ -690,10 +690,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int if TYPE_CHECKING: def cast_listlistint(x): - return cast(List[List[int]], x) + return cast(list[list[int]], x) def cast_listint(x): - return cast(List[int], x) + return cast(list[int], x) else: # fake cast op for use at runtime since dynamo doesn't support real cast @@ -705,7 +705,7 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int def cast_listint(x): return x - rankset: List[int] + rankset: list[int] if isinstance(group, list): if isinstance(group[0], list): nested_list = cast_listlistint(group) @@ -1143,7 +1143,7 @@ def all_to_all_inplace( def all_gather_inplace( - tensor_list: List[torch.Tensor], + tensor_list: list[torch.Tensor], tensor: torch.Tensor, group=None, async_op=False, diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index 4bd193d662b..0c1ac0a079d 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import List, Optional +from typing import Optional import torch import torch.distributed.distributed_c10d as c10d @@ -60,7 +60,7 @@ def _reduce_scatter_tensor( input: torch.Tensor, reduce_op: str, tag: str, - ranks: List[int], + ranks: list[int], group_size: int, ): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) @@ -73,10 +73,10 @@ def _reduce_scatter_tensor( def _reduce_scatter_tensor_coalesced( - inputs: List[torch.Tensor], + inputs: list[torch.Tensor], reduce_op: str, tag: str, - ranks: List[int], + ranks: list[int], group_size: int, ): group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) @@ -90,10 +90,10 @@ def _reduce_scatter_tensor_coalesced( def _all_to_all_single( input: torch.Tensor, - output_split_sizes: Optional[List[int]], - input_split_sizes: Optional[List[int]], + output_split_sizes: Optional[list[int]], + input_split_sizes: Optional[list[int]], tag: str, - ranks: List[int], + ranks: list[int], group_size: int, ): if output_split_sizes is None or input_split_sizes is None: diff --git a/torch/distributed/_shard/_utils.py b/torch/distributed/_shard/_utils.py index d06fc4dc961..6fd641b3f94 100644 --- a/torch/distributed/_shard/_utils.py +++ b/torch/distributed/_shard/_utils.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence import torch from torch.distributed._shard.metadata import ShardMetadata diff --git a/torch/distributed/_shard/metadata.py b/torch/distributed/_shard/metadata.py index 2611d13ef3a..1dce5b44df2 100644 --- a/torch/distributed/_shard/metadata.py +++ b/torch/distributed/_shard/metadata.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from dataclasses import dataclass from functools import reduce -from typing import List, Optional, Union +from typing import Optional, Union from torch.distributed.remote_device import _remote_device @@ -25,14 +25,14 @@ class ShardMetadata: __slots__ = ["shard_offsets", "shard_sizes", "placement"] - shard_offsets: List[int] - shard_sizes: List[int] + shard_offsets: list[int] + shard_sizes: list[int] placement: Optional[_remote_device] def __init__( self, - shard_offsets: List[int], - shard_sizes: List[int], + shard_offsets: list[int], + shard_sizes: list[int], placement: Optional[Union[str, _remote_device]] = None, ): self.shard_offsets = shard_offsets diff --git a/torch/distributed/_shard/sharded_optim/__init__.py b/torch/distributed/_shard/sharded_optim/__init__.py index 8b9db18ef1f..7deab8d253d 100644 --- a/torch/distributed/_shard/sharded_optim/__init__.py +++ b/torch/distributed/_shard/sharded_optim/__init__.py @@ -1,4 +1,5 @@ -from typing import Iterator, Tuple, Union +from collections.abc import Iterator +from typing import Tuple, Union import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py index c0ddbece766..8c899437346 100644 --- a/torch/distributed/_shard/sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs -from typing import Any, Dict, List, Mapping, Union +from collections.abc import Mapping +from typing import Any, Union import torch.optim as optim from torch import Tensor @@ -28,7 +29,7 @@ class ShardedOptimizer(optim.Optimizer): **optimizer_kwargs: the key-word arguments to initialize the optimizer. """ - tensors: List[Tensor] = [] + tensors: list[Tensor] = [] for value in named_params.values(): if isinstance(value, ShardedTensor): tensors.extend( @@ -72,7 +73,7 @@ class ShardedOptimizer(optim.Optimizer): """ self._optim.step(closure) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """ Returned state and param_groups will contain parameter keys instead of parameter indices like torch.optim.Optimizer. diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index db7090820ea..881193cf0ce 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -356,7 +356,7 @@ def randn( def init_from_local_shards( - local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False + local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False ) -> ShardedTensor: """ Creates an :class:`ShardedTensor` from local shards and the global metadata. diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index ef7be1f43cc..f816fb2af0d 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -8,7 +8,7 @@ import warnings import weakref from dataclasses import dataclass from functools import reduce -from typing import Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING +from typing import Callable, cast, Optional, TYPE_CHECKING from typing_extensions import deprecated import torch @@ -41,23 +41,25 @@ from .utils import ( if TYPE_CHECKING: + from collections.abc import Sequence + from torch.distributed._shard.metadata import ShardMetadata # Tracking for sharded tensor objects. _sharded_tensor_lock = threading.Lock() _sharded_tensor_current_id = 0 -_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {} +_sharded_tensor_map: dict[int, weakref.ReferenceType[ShardedTensor]] = {} # Default sharded ops -_SHARDED_OPS: Dict[Callable, Callable] = {} +_SHARDED_OPS: dict[Callable, Callable] = {} # Customized user ops -_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {} +_CUSTOM_SHARDED_OPS: dict[Callable, Callable] = {} def _register_remote_shards( - sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int + sharded_tensor_id: int, rrefs: list[rpc.RRef[Shard]], rpc_rank: int ): with _sharded_tensor_lock: if sharded_tensor_id not in _sharded_tensor_map: @@ -75,7 +77,7 @@ def _register_remote_shards( class ShardedTensorBase(torch.Tensor): _sharding_spec: shard_spec.ShardingSpec _metadata: ShardedTensorMetadata - _local_shards: List[Shard] + _local_shards: list[Shard] def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): # Use __new__ to construct a wrapper tensor, for recording tensor @@ -125,7 +127,7 @@ class ShardedTensorBase(torch.Tensor): """ return self._metadata - def local_shards(self) -> List[Shard]: + def local_shards(self) -> list[Shard]: """ Returns a list of :class:`Shard' corresponding to the local shards for this rank. Returns an empty list if the current rank @@ -136,7 +138,7 @@ class ShardedTensorBase(torch.Tensor): @classmethod def _init_from_local_shards_and_global_metadata( cls, - local_shards: List[Shard], + local_shards: list[Shard], sharded_tensor_metadata: ShardedTensorMetadata, sharding_spec=None, ) -> ShardedTensorBase: @@ -290,7 +292,7 @@ class ShardedTensor(ShardedTensorBase): self._sharded_tensor_id = None self._process_group = self._normalize_pg(process_group) - self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {} + self._remote_shards: dict[int, list[rpc.RRef[Shard]]] = {} def _post_init(self): # Initialize RPC if available. @@ -351,7 +353,7 @@ class ShardedTensor(ShardedTensorBase): continue if len(self.local_shards()) != 0: - rrefs: List[rpc.RRef[Shard]] = [ + rrefs: list[rpc.RRef[Shard]] = [ rpc.RRef(shard) for shard in self.local_shards() ] fut = rpc.rpc_async( @@ -430,7 +432,7 @@ class ShardedTensor(ShardedTensorBase): world_size = dist.get_world_size(self._process_group) rank_sizes = [0 for _ in range(world_size)] max_rank_size = 0 - shard_placement: Dict[ShardMetadata, tuple[int, int]] = {} + shard_placement: dict[ShardMetadata, tuple[int, int]] = {} # collect sizes for shard_md in self.metadata().shards_metadata: shard_rank = cast(_remote_device, shard_md.placement).rank() @@ -440,7 +442,7 @@ class ShardedTensor(ShardedTensorBase): rank_sizes[shard_rank] += shard_size(shard_md) max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) - gather_list: Optional[List[torch.Tensor]] + gather_list: Optional[list[torch.Tensor]] if rank == dst: assert out is not None if enforce_dtype: @@ -535,7 +537,7 @@ class ShardedTensor(ShardedTensorBase): return self # if not, returns a copy of this object on CPU - list_shards: List[Shard] = [] + list_shards: list[Shard] = [] # move all local shards to cpu, and change metadata for shard in self._local_shards: cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] @@ -594,7 +596,7 @@ class ShardedTensor(ShardedTensorBase): current_device = torch.device(torch.cuda.current_device()) # returns a copy of ShardedTensor on CUDA current device - list_shards: List[Shard] = [] + list_shards: list[Shard] = [] # move all local shards to current device, and change metadata # if local shards already on the current device, there's no # real data movement, only the metadata are copied. @@ -683,7 +685,7 @@ class ShardedTensor(ShardedTensorBase): return self # returns a copy of ShardedTensor on CUDA current device - list_shards: List[Shard] = [] + list_shards: list[Shard] = [] for shard in self._local_shards: new_tensor = shard.tensor.to( # type: ignore[call-overload] @@ -726,7 +728,7 @@ class ShardedTensor(ShardedTensorBase): @classmethod def _init_from_local_shards( cls, - local_shards: List[Shard], + local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False, @@ -746,7 +748,7 @@ class ShardedTensor(ShardedTensorBase): # STEP 2. Validate metadata across ranks, and build a global sharded tensor # metadata by gathering local ShardedTensorMetadata - gathered_metadatas: List[Optional[ShardedTensorMetadata]] = [] + gathered_metadatas: list[Optional[ShardedTensorMetadata]] = [] if world_size > 1: gathered_metadatas = [None for _ in range(world_size)] @@ -866,7 +868,7 @@ class ShardedTensor(ShardedTensorBase): process_group = cls._normalize_pg(process_group) current_rank = dist.get_rank() # intentional to get global rank - local_shards: List[Shard] = [] + local_shards: list[Shard] = [] for shard_metadata in sharded_tensor_metadata.shards_metadata: rank, _device = _parse_and_validate_remote_device( process_group, shard_metadata.placement @@ -887,7 +889,7 @@ class ShardedTensor(ShardedTensorBase): @classmethod def _init_from_local_shards_and_global_metadata( # type: ignore[override] cls, - local_shards: List[Shard], + local_shards: list[Shard], sharded_tensor_metadata: ShardedTensorMetadata, process_group=None, init_rrefs=False, @@ -1190,11 +1192,11 @@ class ShardedTensor(ShardedTensorBase): return self._metadata.tensor_properties.pin_memory def _register_remote_shards( - self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int + self, remote_shards: list[rpc.RRef[Shard]], rpc_rank: int ): self._remote_shards[rpc_rank] = remote_shards - def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]: + def remote_shards(self) -> dict[int, list[rpc.RRef[Shard]]]: """ Returns a Dict[int, RRef] with keys being the RPC rank and values being RRefs to shards on that rank. Need to initialize the diff --git a/torch/distributed/_shard/sharded_tensor/logger.py b/torch/distributed/_shard/sharded_tensor/logger.py index 39ee4380703..ff8cb4d18fb 100644 --- a/torch/distributed/_shard/sharded_tensor/logger.py +++ b/torch/distributed/_shard/sharded_tensor/logger.py @@ -7,12 +7,11 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import List from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers -__all__: List[str] = [] +__all__: list[str] = [] def _get_or_create_logger() -> logging.Logger: diff --git a/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/torch/distributed/_shard/sharded_tensor/logging_handlers.py index 021ad100f06..ed6832fd1ae 100644 --- a/torch/distributed/_shard/sharded_tensor/logging_handlers.py +++ b/torch/distributed/_shard/sharded_tensor/logging_handlers.py @@ -7,11 +7,10 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Dict, List -__all__: List[str] = [] +__all__: list[str] = [] -_log_handlers: Dict[str, logging.Handler] = { +_log_handlers: dict[str, logging.Handler] = { "default": logging.NullHandler(), } diff --git a/torch/distributed/_shard/sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py index e53ac25fa55..466ca1a0c51 100644 --- a/torch/distributed/_shard/sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs from dataclasses import dataclass, field from enum import Enum -from typing import List import torch from torch.distributed._shard.metadata import ShardMetadata @@ -87,7 +86,7 @@ class ShardedTensorMetadata: """ # Metadata about each shard of the Tensor - shards_metadata: List[ShardMetadata] = field(default_factory=list) + shards_metadata: list[ShardMetadata] = field(default_factory=list) # Size of each dim of the overall Tensor. size: torch.Size = field(default=torch.Size([])) diff --git a/torch/distributed/_shard/sharded_tensor/reshard.py b/torch/distributed/_shard/sharded_tensor/reshard.py index 858ed273adc..30505721ff9 100644 --- a/torch/distributed/_shard/sharded_tensor/reshard.py +++ b/torch/distributed/_shard/sharded_tensor/reshard.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import copy -from typing import List import torch import torch.distributed as dist @@ -44,7 +43,7 @@ def build_reshard_metadata( st_size: torch.Size, sharding_spec: shard_spec.ShardingSpec, world_size: int, -) -> tuple[List[ShardMetadata], List[int]]: +) -> tuple[list[ShardMetadata], list[int]]: """ Based the given sharding spec, we calculate the offset and local shard size. We then build a ShardMetadata on top of the calculation result. @@ -86,7 +85,7 @@ def reshuffle_local_shard( sharding_spec: shard_spec.ShardingSpec, resharding_spec: shard_spec.ShardingSpec, pg: ProcessGroup, -) -> tuple[List[Shard], List[ShardMetadata]]: +) -> tuple[list[Shard], list[ShardMetadata]]: """ Reshuffle the local shard directly when the reshard dim is same as the original sharding dim. Logically we do this in two step: @@ -155,7 +154,7 @@ def reshard_local_shard( sharding_spec: shard_spec.ShardingSpec, resharding_spec: shard_spec.ShardingSpec, pg: ProcessGroup, -) -> tuple[List[Shard], List[ShardMetadata]]: +) -> tuple[list[Shard], list[ShardMetadata]]: """ Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is different from the original sharding dim, we need to do two steps logically: @@ -198,7 +197,7 @@ def reshard_local_shard( if rearrange_input: # Need to re-arrange reshard_dim of local_tensor before all2all. - indices: List[int] = [] + indices: list[int] = [] for metadata in shards_metadata: offset_start_idx = metadata.shard_offsets[reshard_dim] split_size = metadata.shard_sizes[reshard_dim] diff --git a/torch/distributed/_shard/sharded_tensor/shard.py b/torch/distributed/_shard/sharded_tensor/shard.py index cdacc2d3b20..2d9d4357436 100644 --- a/torch/distributed/_shard/sharded_tensor/shard.py +++ b/torch/distributed/_shard/sharded_tensor/shard.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import List import torch from torch.distributed._shard.metadata import ShardMetadata @@ -43,7 +42,7 @@ class Shard: @classmethod def from_tensor_and_offsets( - cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int + cls, tensor: torch.Tensor, shard_offsets: list[int], rank: int ) -> "Shard": """ Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank. diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index a6954813f82..bde20dd3e61 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import collections.abc import copy -from typing import List, Optional, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Optional, TYPE_CHECKING import torch from torch.distributed import distributed_c10d as c10d, rpc @@ -109,13 +110,13 @@ def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): def build_metadata_from_local_shards( - local_shards: List[Shard], + local_shards: list[Shard], global_size: torch.Size, current_rank: int, pg: c10d.ProcessGroup, ) -> ShardedTensorMetadata: assert len(local_shards) > 0, "must have local shards!" - local_shard_metadatas: List[ShardMetadata] = [] + local_shard_metadatas: list[ShardMetadata] = [] first_shard_dtype = local_shards[0].tensor.dtype first_shard_layout = local_shards[0].tensor.layout diff --git a/torch/distributed/_shard/sharding_plan/api.py b/torch/distributed/_shard/sharding_plan/api.py index 217ef8dab1e..7fc6080031f 100644 --- a/torch/distributed/_shard/sharding_plan/api.py +++ b/torch/distributed/_shard/sharding_plan/api.py @@ -1,6 +1,6 @@ import abc from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch.nn as nn from torch.distributed._shard.sharder import Sharder @@ -62,9 +62,9 @@ class ShardingPlan: >>> ) """ - plan: Dict[str, Union[ShardingSpec, Sharder]] - output_plan: Optional[Dict[str, ShardingSpec]] = None - return_local_tensor: Optional[List[str]] = None + plan: dict[str, Union[ShardingSpec, Sharder]] + output_plan: Optional[dict[str, ShardingSpec]] = None + return_local_tensor: Optional[list[str]] = None class ShardingPlanner(abc.ABC): diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index 69193f5c7d8..bcbacb40917 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import List, Optional +from typing import Optional from torch.distributed._shard.metadata import ShardMetadata @@ -24,7 +24,7 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad def _find_nd_overlapping_shards( - shards: List[ShardMetadata], sharded_dims: List[int] + shards: list[ShardMetadata], sharded_dims: list[int] ) -> Optional[tuple[int, int]]: # Each rank has len(sharded_dims) tuples. Each tuple represent the # [begin, end] (inclusive) pair of that dimension. @@ -55,7 +55,7 @@ def _find_nd_overlapping_shards( def _find_1d_overlapping_shards( - shards: List[ShardMetadata], dim: int + shards: list[ShardMetadata], dim: int ) -> Optional[tuple[int, int]]: # (begin, end, index_in_shards). Begin and end are inclusive. intervals = [ @@ -69,7 +69,7 @@ def _find_1d_overlapping_shards( return None -def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): +def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): """ Ensures none of the shards overlap with each other. @@ -82,7 +82,7 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]): if not shards or len(shards) == 1: return - sharded_dims: List[int] = [] + sharded_dims: list[int] = [] for dim in range(len(shards[0].shard_offsets)): for i in range(1, len(shards)): if ( diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py index 91355e356c3..b24f28d973a 100644 --- a/torch/distributed/_shard/sharding_spec/api.py +++ b/torch/distributed/_shard/sharding_spec/api.py @@ -3,7 +3,7 @@ import functools import operator from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Dict, List, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING import torch import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta @@ -95,7 +95,7 @@ class ShardingSpec(ABC): # Ops customized for a particular ShardingSpec. -_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {} +_CUSTOM_SHARDING_SPEC_OPS: dict[str, dict[Callable, Callable]] = {} def _has_custom_op(sharding_spec, op): @@ -148,7 +148,7 @@ class EnumerableShardingSpec(ShardingSpec): each shard. Note that none of the shards should overlap. """ - shards: List[ShardMetadata] + shards: list[ShardMetadata] def __post_init__(self): if len(self.shards) == 0: diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index ec277eb84bf..e8eaeabb9f9 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from dataclasses import dataclass -from typing import cast, List, Optional, TYPE_CHECKING, Union +from typing import cast, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -53,7 +53,7 @@ class ChunkShardingSpec(ShardingSpec): ShardingDim = Union[int, str] dim: ShardingDim - placements: List[Union[torch.distributed._remote_device, str]] + placements: list[Union[torch.distributed._remote_device, str]] def __post_init__(self): self._verify_dim(self.dim) @@ -134,7 +134,7 @@ class ChunkShardingSpec(ShardingSpec): local_metadata = None tensors_to_scatter = cast( - List[Optional[torch.Tensor]], + list[Optional[torch.Tensor]], [None] * dist.get_world_size(process_group), ) @@ -195,7 +195,7 @@ class ChunkShardingSpec(ShardingSpec): process_group, src_for_scatter ) - tensors_to_scatter_: Optional[List[torch.Tensor]] = None + tensors_to_scatter_: Optional[list[torch.Tensor]] = None if current_rank == src_rank: tensors_to_scatter_ = [] for t in tensors_to_scatter: diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index 01a148b5a9a..61808d0adf6 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs -from typing import cast, List +from typing import cast import torch import torch.distributed as dist @@ -222,7 +222,7 @@ def _validate_embedding_bag_param(args, kwargs): ) if include_last_offset and offsets is None: raise ValueError('offsets is required for flag "include_last_offset"!') - if include_last_offset and cast(List[int], offsets)[-1] != input.size(0): + if include_last_offset and cast(list[int], offsets)[-1] != input.size(0): raise ValueError( 'offsets need to have the input size in the end when the flag "include_last_offset" is on!' ) diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 932bdc987ea..f2bb8d16479 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -3,19 +3,8 @@ import copy import io import math import weakref -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Mapping, - MutableMapping, - NamedTuple, - Optional, - TYPE_CHECKING, - Union, -) +from collections.abc import Mapping, MutableMapping +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -94,7 +83,7 @@ def _iterate_state_dict( ranks_only: tuple[int, ...] = (), type_check: bool = True, non_blocking: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Iterate through the state dict, applying the given functions to each tensor type. Args: @@ -203,14 +192,14 @@ def _iterate_state_dict( def _gather_state_dict( - state_dict: Dict[str, Any], + state_dict: dict[str, Any], *, pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, cpu_offload: bool = False, ranks_only: tuple[int, ...] = (), type_check: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Given a state_dict, this API gathers all the ShardedTensors or DTensors in the state_dict. @@ -291,11 +280,11 @@ def _gather_state_dict( def _offload_state_dict_to_cpu( - state_dict: Dict[str, Any], + state_dict: dict[str, Any], *, ranks_only: tuple[int, ...] = (), type_check: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Given a state_dict, this API offload all the tensors to CPU memory. @@ -331,11 +320,11 @@ def _offload_state_dict_to_cpu( def _copy_state_dict( - state_dict: Dict[str, Any], - copy_state_dict: Dict[str, Any], + state_dict: dict[str, Any], + copy_state_dict: dict[str, Any], non_blocking: bool = False, type_check: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Copies all tensors in a given state dict into a different state_dict with the same structure. Additionally, a copied state dict with the same value references @@ -380,8 +369,8 @@ def _copy_state_dict( def _create_cpu_state_dict( - state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False -) -> Dict[str, Any]: + state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> dict[str, Any]: """ Given a state_dict, create another state_dict with the same structure and elements. However, all tensors in the returned state_dict are new tensors on CPU. These @@ -449,8 +438,8 @@ def _create_cpu_state_dict( def _check_state_dict_similarity( - state_dict: Dict[str, Any], - compared_state_dict: Dict[str, Any], + state_dict: dict[str, Any], + compared_state_dict: dict[str, Any], ) -> bool: """ Given two state_dicts, check if the structures are the same. And @@ -496,9 +485,9 @@ class _TensorInfo(NamedTuple): def _broadcast_tensors( - full_state_dict: Dict[str, Any], - local_state_dict: Dict[str, Any], - keys: List[str], + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + keys: list[str], device: torch.device, pg: Optional[dist.ProcessGroup] = None, ) -> None: @@ -536,8 +525,8 @@ def _broadcast_tensors( def _distribute_tensors( - local_state_dict: Dict[str, Any], - keys: List[str], + local_state_dict: dict[str, Any], + keys: list[str], device: torch.device, pg: Optional[dist.ProcessGroup] = None, ) -> None: @@ -578,8 +567,8 @@ def _distribute_tensors( def _broadcast_state_dict( - full_state_dict: Dict[str, Any], - local_state_dict: Dict[str, Any], + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], device: torch.device, pg: Optional[dist.ProcessGroup] = None, strict: bool = False, @@ -637,8 +626,8 @@ def _broadcast_state_dict( def _distribute_state_dict( - full_state_dict: Dict[str, Any], - local_state_dict: Dict[str, Any], + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], device: torch.device, pg: Optional[dist.ProcessGroup] = None, ) -> None: @@ -672,8 +661,8 @@ def _distribute_state_dict( # DCP. PATH_ITEM = Union[str, int] OBJ_PATH = tuple[PATH_ITEM, ...] -FLATTEN_MAPPING = Dict[str, OBJ_PATH] -STATE_DICT_TYPE = Dict[str, Any] +FLATTEN_MAPPING = dict[str, OBJ_PATH] +STATE_DICT_TYPE = dict[str, Any] CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] @@ -731,14 +720,14 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None """Set ``value`` in ``root_dict`` along the ``path`` object path.""" cur_container = cast(CONTAINER_TYPE, root_dict) - def extend_list(lst: List[Any], idx: int) -> None: + def extend_list(lst: list[Any], idx: int) -> None: while len(lst) <= idx: lst.append(None) for i in range(1, len(path)): prev_key = path[i - 1] key = path[i] - def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else [] + def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else [] if isinstance(cur_container, Mapping): cur_container = cast( @@ -752,7 +741,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None key = path[-1] if type(key) == int: - extend_list(cast(List[Any], cur_container), key) + extend_list(cast(list[Any], cur_container), key) cur_container[key] = value diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index c511c536751..798cc48ead7 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -2,11 +2,12 @@ import math import os import socket import uuid +from collections.abc import Generator from contextlib import contextmanager from datetime import timedelta from enum import Enum from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.distributed._functional_collectives as funcol @@ -15,7 +16,7 @@ from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work -_group_name_to_store: Dict[str, c10d.Store] = {} +_group_name_to_store: dict[str, c10d.Store] = {} def enable_symm_mem_for_group(group_name: str) -> None: @@ -95,7 +96,7 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool: return _is_test_mode or group_name in _group_name_to_store -_group_name_to_workspace_tensor: Dict[str, Optional[torch.Tensor]] = {} +_group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {} def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory: @@ -139,7 +140,7 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory: return _SymmetricMemory.rendezvous(tensor) -_backend_streams: Dict[int, torch.cuda.Stream] = {} +_backend_streams: dict[int, torch.cuda.Stream] = {} def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream: @@ -149,9 +150,9 @@ def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream: def _pipelined_multi_all_gather_and_consume( - shard: List[torch.Tensor], - shard_consumer: Callable[[List[torch.Tensor], int], None], - ag_out: List[torch.Tensor], + shard: list[torch.Tensor], + shard_consumer: Callable[[list[torch.Tensor], int], None], + ag_out: list[torch.Tensor], group_name: str, ag_out_needed: bool = True, ) -> None: @@ -195,11 +196,11 @@ def _pipelined_multi_all_gather_and_consume( assert x.shape[0] * group_size == y.shape[0] assert x.shape[1:] == y.shape[1:] - def copy_shard(dst: List[torch.Tensor], src: List[torch.Tensor]) -> None: + def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None: for d, s in zip(dst, src): d.copy_(s) - def get_p2p_bufs(remote_rank: int) -> List[torch.Tensor]: + def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]: offset_bytes = 0 bufs = [] for x in shard: @@ -216,7 +217,7 @@ def _pipelined_multi_all_gather_and_consume( local_p2p_bufs = get_p2p_bufs(rank) # shards[i] => shard from rank i - shards: List[List[torch.Tensor]] = [[] for _ in range(group_size)] + shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)] for x in ag_out: for i, y in enumerate(x.chunk(group_size)): shards[i].append(y) @@ -311,7 +312,7 @@ def _pipelined_all_gather_and_consume( shard_consumer(shard, src_rank) """ - def adapter(shard: List[torch.Tensor], rank: int) -> None: + def adapter(shard: list[torch.Tensor], rank: int) -> None: shard_consumer(shard[0], rank) _pipelined_multi_all_gather_and_consume( @@ -505,14 +506,14 @@ def _check_and_verify_fp8_all_gather_scale_mode( def _fused_all_gather_matmul_impl( mm_out_op: torch._ops.OpOverload, A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], A_scale: Optional[torch.Tensor], - kwargs_list: List[Dict[str, Any]], - out_dtypes: List[Optional[torch.dtype]], + kwargs_list: list[dict[str, Any]], + out_dtypes: list[Optional[torch.dtype]], gather_dim: int, group_name: str, return_A: bool, -) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]: +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: if A_shard.dim() < 2: raise ValueError("A_shard must be a matrix") for B in Bs: @@ -563,7 +564,7 @@ def _fused_all_gather_matmul_impl( A_scale_shard.shape[1], ) - def row_wise_sharded_consumer(shard: List[torch.Tensor], rank: int) -> None: + def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None: for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): mm_out_op( shard[0], @@ -630,12 +631,12 @@ def _fused_all_gather_matmul_impl( @torch.library.impl(lib, "fused_all_gather_matmul", "Meta") def _fused_all_gather_matmul_fallback( A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], gather_dim: int, group_name: str, *, return_A: bool = True, -) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]: +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: group_size = c10d._get_group_size_by_name(group_name) A = torch.ops._c10d_functional.all_gather_into_tensor( A_shard.contiguous(), group_size, group_name @@ -652,12 +653,12 @@ def _fused_all_gather_matmul_fallback( @torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") def _fused_all_gather_matmul( A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], gather_dim: int, group_name: str, *, return_A: bool = True, -) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]: +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: """ Perform the following logic with micro-pipelined computation and communication: @@ -703,7 +704,7 @@ def _fused_all_gather_matmul( def _should_use_fused_all_gather_matmul_native( A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], gather_dim: int, group_name: str, ) -> bool: @@ -806,9 +807,9 @@ def _should_use_multimem_all_gather_matmul( def _multimem_all_gather_matmul( A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], group_name: str, -) -> List[torch.Tensor]: +) -> list[torch.Tensor]: group = c10d._resolve_process_group(group_name) A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:])) symm_mem = get_symm_mem_workspace( @@ -822,16 +823,16 @@ def _multimem_all_gather_matmul( @torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") def _fused_all_gather_scaled_matmul_fallback( A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], A_scale: torch.Tensor, - B_scales: List[torch.Tensor], + B_scales: list[torch.Tensor], gather_dim: int, group_name: str, - biases: List[Optional[torch.Tensor]], - result_scales: List[Optional[torch.Tensor]], - out_dtypes: List[Optional[torch.dtype]], - use_fast_accum: List[bool], -) -> tuple[torch.Tensor, List[torch.Tensor]]: + biases: list[Optional[torch.Tensor]], + result_scales: list[Optional[torch.Tensor]], + out_dtypes: list[Optional[torch.dtype]], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) group_size = c10d._get_group_size_by_name(group_name) @@ -896,16 +897,16 @@ def _fused_all_gather_scaled_matmul_fallback( @torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") def _fused_all_gather_scaled_matmul( A_shard: torch.Tensor, - Bs: List[torch.Tensor], + Bs: list[torch.Tensor], A_scale: torch.Tensor, - B_scales: List[torch.Tensor], + B_scales: list[torch.Tensor], gather_dim: int, group_name: str, - biases: List[Optional[torch.Tensor]], - result_scales: List[Optional[torch.Tensor]], - out_dtypes: List[Optional[torch.dtype]], - use_fast_accum: List[bool], -) -> tuple[torch.Tensor, List[torch.Tensor]]: + biases: list[Optional[torch.Tensor]], + result_scales: list[Optional[torch.Tensor]], + out_dtypes: list[Optional[torch.dtype]], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: """ Perform the following logic with micro-pipelined computation and communication: @@ -976,7 +977,7 @@ def _fused_all_gather_scaled_matmul( def make_contiguous_for_perm( t: torch.Tensor, - perm: List[int], + perm: list[int], ) -> torch.Tensor: """ Restride `t` such that `t.permute(perm)` is contiguous. @@ -1005,7 +1006,7 @@ def _fused_matmul_reduce_scatter_impl( A: torch.Tensor, B: torch.Tensor, A_scale: Optional[torch.Tensor], - kwargs: Dict[str, Any], + kwargs: dict[str, Any], out_dtype: Optional[torch.dtype], reduce_op: str, scatter_dim: int, @@ -1237,8 +1238,8 @@ def restride_A_for_fused_matmul_reduce_scatter( def _maybe_convert_scalar_types_to_dtypes( - scalar_types: List[Any], -) -> List[Optional[torch.dtype]]: + scalar_types: list[Any], +) -> list[Optional[torch.dtype]]: """ When a list of `torch.dtype`s is passed through the dispatcher as `ScalarType[]`, it is converted to a list of scalar type enum values. This @@ -1270,7 +1271,7 @@ def _maybe_convert_scalar_types_to_dtypes( if any(not isinstance(x, (type(None), int)) for x in scalar_types): return scalar_types - dtypes: List[Optional[torch.dtype]] = [] + dtypes: list[Optional[torch.dtype]] = [] for scalar_type in scalar_types: if scalar_type is None: dtypes.append(scalar_type) @@ -1507,7 +1508,8 @@ def _low_contention_reduce_scatter( # ============================================================================= -from typing import Any, overload, Sequence, TYPE_CHECKING, Union +from collections.abc import Sequence +from typing import Any, overload, TYPE_CHECKING, Union from torch.types import _device, _dtype, _int diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index e2ed6a29759..c7a67ebee3d 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -1,7 +1,7 @@ from copy import deepcopy from datetime import timedelta from functools import partial, wraps -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, TypeVar, Union +from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union from typing_extensions import ParamSpec, TypeVarTuple, Unpack import torch @@ -111,9 +111,9 @@ class _FSDPModMemStats: def __init__(self, mod_fqn: str) -> None: self.mod_fqn = mod_fqn - self.local_peak: Dict[torch.device, int] = {} - self.snapshots: Dict[ - _FSDPModState, List[Dict[torch.device, Dict[str, int]]] + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[ + _FSDPModState, list[dict[torch.device, dict[str, int]]] ] = {} @@ -169,7 +169,7 @@ class FSDPMemTracker(MemTracker): self._in_fake_mode: bool = False self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary() self._saved_collectives: _SavedCollectives - self._ref_class: Type[_RefType] = _FSDPRefType + self._ref_class: type[_RefType] = _FSDPRefType def _instrument_fsdp_sharded_params_grads( self, fsdp_param_group: FSDPParamGroup @@ -190,8 +190,8 @@ class FSDPMemTracker(MemTracker): def _fsdp_state_pre_forward( self, fsdp_mod: FSDPModule, - orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], Dict[str, Any]]], - ) -> Callable[_P, tuple[tuple[Unpack[_Ts]], Dict[str, Any]]]: + orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]], + ) -> Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]]: # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params # and `all_gather` buffers. There are three cases: # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats`` @@ -208,7 +208,7 @@ class FSDPMemTracker(MemTracker): @wraps(orig_fsdp_state_pre_fw) def inner( *args: _P.args, **kwargs: _P.kwargs - ) -> tuple[tuple[Unpack[_Ts]], Dict[str, Any]]: + ) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]: mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod) assert mod_fqn is not None if fsdp_mod not in self.memory_tracking: @@ -538,7 +538,7 @@ class FSDPMemTracker(MemTracker): def barrier( group: Union[ProcessGroup, None] = dist.GroupMember.WORLD, async_op: bool = False, - device_ids: Union[List[int], None] = None, + device_ids: Union[list[int], None] = None, ) -> Union[Work, None]: if self._in_fake_mode: return None diff --git a/torch/distributed/_tools/ilp_utils.py b/torch/distributed/_tools/ilp_utils.py index ac5c73ec8e4..b3c2980dd3b 100644 --- a/torch/distributed/_tools/ilp_utils.py +++ b/torch/distributed/_tools/ilp_utils.py @@ -1,5 +1,6 @@ import copy -from typing import cast, Dict, List, OrderedDict, TypedDict +from collections import OrderedDict +from typing import cast, TypedDict import numpy as np @@ -15,10 +16,10 @@ from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStat class ModOrder(TypedDict): - fw_pre_order: List[str] - bw_pre_order: List[str] - fw_post_order: List[str] - bw_post_order: List[str] + fw_pre_order: list[str] + bw_pre_order: list[str] + fw_post_order: list[str] + bw_post_order: list[str] class ModRuntime(TypedDict): @@ -60,18 +61,18 @@ class ModStats(TypedDict): # Number of piecewise-linear functions used for approximating ac tradeoff curve n_segments: int # Slopes of the of piecewise-linear functions - slopes: List[float] + slopes: list[float] # Intercepts of the of piecewise-linear functions - intercepts: List[float] + intercepts: list[float] # X breakpoints of the of piecewise-linear functions - breakpoints: List[float] + breakpoints: list[float] # Original trade-off curves tradeoff_curve: OrderedDict[float, float] class ModuleInfo(TypedDict): mod_order: ModOrder - mod_stats: List[ModStats] + mod_stats: list[ModStats] def aggregate_stats( @@ -96,12 +97,12 @@ def aggregate_stats( """ # Memory stats - mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict( + mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict( copy.deepcopy(mem_tracker.memory_tracking) ) # Runtime stats - mod_runtime_stats: Dict[str, ModRuntime] = { + mod_runtime_stats: dict[str, ModRuntime] = { fqn: {"fw": v["fw"], "bw": v["bw"]} for fqn, v in runtime_estimator.mod_runtimes.items() } @@ -116,7 +117,7 @@ def aggregate_stats( # Selective Activation Checkpointing stats sac_estimator.pwlf_sac_tradeoff_curve() - mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy( + mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy( sac_estimator.sac_mod_tradeoff_stats ) @@ -192,10 +193,10 @@ class Node(ModStats): class Graph: def __init__(self, n: int) -> None: - self.nodes: List[Node] = [] - self.name2node: Dict[str, Node] = {} + self.nodes: list[Node] = [] + self.name2node: dict[str, Node] = {} self.ad_matrix = np.zeros((n, n)) - self.fw_post_order: List[str] = [] + self.fw_post_order: list[str] = [] def add_node(self, node: Node) -> None: self.nodes.append(node) diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index 2e9d13f471e..b72987af6f7 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -5,7 +5,7 @@ import warnings from copy import deepcopy from enum import auto, Enum from functools import partial, wraps -from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import Self import torch @@ -116,8 +116,8 @@ class _ModMemStats: self.buffer_mem: int self.input_mem: int self.output_mem: int - self.local_peak: Dict[torch.device, int] = {} - self.snapshots: Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]] = {} + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[_ModState, list[dict[torch.device, dict[str, int]]]] = {} class _WeakRefInfo: @@ -171,7 +171,7 @@ class _WeakRefInfo: return self.mem_consumed @staticmethod - def get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]: + def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]: """ Recursively extracts untyped storages from a tensor or its subclasses. @@ -242,7 +242,7 @@ def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]: return value if divisor == 1 else round(value / divisor, precision) -def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> None: +def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) -> None: if len(snapshot) == 0: print("No memory tracked.") return @@ -261,7 +261,7 @@ def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> def _print_snapshot_tabular( - snapshot: Dict[torch.device, Dict[str, int]], units: str + snapshot: dict[torch.device, dict[str, int]], units: str ) -> None: if len(snapshot) == 0: print("No memory tracked.") @@ -287,7 +287,7 @@ def _print_snapshot_tabular( def _print_state_snapshots( - snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str ) -> None: for state, snapshot_list in snapshots.items(): print(f"{state}") @@ -298,7 +298,7 @@ def _print_state_snapshots( def _print_state_snapshots_tabular( - snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str ) -> None: try: from tabulate import tabulate @@ -393,9 +393,9 @@ class MemTracker(TorchDispatchMode): def __init__(self) -> None: self.memory_tracking = WeakIdKeyDictionary() - self._curr_mem_snap: Dict[torch.device, Dict[str, int]] = {} - self._peak_mem: Dict[torch.device, int] = {} - self._peak_mem_snap: Dict[torch.device, Dict[str, int]] = {} + self._curr_mem_snap: dict[torch.device, dict[str, int]] = {} + self._peak_mem: dict[torch.device, int] = {} + self._peak_mem_snap: dict[torch.device, dict[str, int]] = {} self._param_to_grad_hook_handles = WeakIdKeyDictionary() self._optimizer_hook_handles: Optional[ tuple[RemovableHandle, RemovableHandle] @@ -404,7 +404,7 @@ class MemTracker(TorchDispatchMode): self._WINFO = WeakIdKeyDictionary() self._mod_tracker = ModTracker() # This is a general memory tracker which can be used with any ``_RefType`` subclass - self._ref_class: Type[_RefType] = _MemRefType + self._ref_class: type[_RefType] = _MemRefType # Flags to track if we are in the AC region or optimizer step region self._in_opt: bool = False self._in_ac: bool = False @@ -461,7 +461,7 @@ class MemTracker(TorchDispatchMode): t: torch.Tensor, reftype: _RefType, update_existing: bool = False, - ) -> Set[_WeakRefInfo]: + ) -> set[_WeakRefInfo]: sts = _WeakRefInfo.get_untyped_storages(t) winfos = set() for st in sts: @@ -565,7 +565,7 @@ class MemTracker(TorchDispatchMode): def get_tracker_snapshot( self, type: str = "current" - ) -> Dict[torch.device, Dict[str, int]]: + ) -> dict[torch.device, dict[str, int]]: """ Capture a snapshot of the memory usage breakdown per device, based on the specified type. @@ -865,7 +865,7 @@ class MemTracker(TorchDispatchMode): tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False. """ - def natural_sort_key(s: str) -> List[Union[int, str]]: + def natural_sort_key(s: str) -> list[Union[int, str]]: return [ int(text) if text.isdigit() else text.lower() for text in re.split("([0-9]+)", s) diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index e4d8aa6e762..7aea4073b01 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -2,8 +2,9 @@ import operator import pickle from collections import defaultdict +from collections.abc import Sequence from itertools import chain -from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING +from typing import Any, Callable, no_type_check, TYPE_CHECKING import torch import torch.nn as nn @@ -72,12 +73,12 @@ class MemoryTracker: def __init__(self) -> None: torch._C._log_api_usage_once("torch.distributed.memory_tracker") - self._hooks: List[RemovableHandle] = [] - self._operator_names: Dict[str, int] = defaultdict(int) - self.memories_allocated: Dict[int, Dict[str, float]] = defaultdict() - self.memories_active: Dict[int, Dict[str, float]] = defaultdict() - self.memories_reserved: Dict[int, Dict[str, float]] = defaultdict() - self._markers: Dict[str, int] = defaultdict(int) + self._hooks: list[RemovableHandle] = [] + self._operator_names: dict[str, int] = defaultdict(int) + self.memories_allocated: dict[int, dict[str, float]] = defaultdict() + self.memories_active: dict[int, dict[str, float]] = defaultdict() + self.memories_reserved: dict[int, dict[str, float]] = defaultdict() + self._markers: dict[str, int] = defaultdict(int) self._cur_module_name: str = "" self._op_index: int = 0 self._num_cuda_retries: int = 0 @@ -133,7 +134,7 @@ class MemoryTracker: The number of the top operators can be configured. """ - op_diff: Dict[str, float] = defaultdict(float) + op_diff: dict[str, float] = defaultdict(float) op_name, previous_allocated_memory = self.memories_allocated[0] for i in range(1, self._op_index): op_name, current_allocated_memory = self.memories_allocated[i] diff --git a/torch/distributed/_tools/mod_tracker.py b/torch/distributed/_tools/mod_tracker.py index 3525ae3f95b..6c4aabbb6d1 100644 --- a/torch/distributed/_tools/mod_tracker.py +++ b/torch/distributed/_tools/mod_tracker.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import warnings import weakref -from typing import Callable, Optional, Set +from typing import Callable, Optional import torch from torch.autograd.graph import register_multi_grad_hook @@ -48,7 +48,7 @@ class ModTracker: """ - parents: Set[str] + parents: set[str] """ A Set containing the fqn for each module currently running their forward """ diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index eeb051d689b..37ac5944d52 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -2,7 +2,7 @@ import math import os from collections import defaultdict -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable from typing_extensions import Self import torch @@ -121,13 +121,13 @@ class RuntimeEstimator(TorchDispatchMode): runtime_estimator.display_modulewise_stats() """ - _float_types: Set[torch.dtype] = { + _float_types: set[torch.dtype] = { torch.float16, torch.bfloat16, torch.float32, torch.float64, } - _no_fallback_kernel: Set[torch._ops._OpNamespace] = set() + _no_fallback_kernel: set[torch._ops._OpNamespace] = set() fake_mode: FakeTensorMode def __init__(self) -> None: @@ -135,13 +135,13 @@ class RuntimeEstimator(TorchDispatchMode): self._estimate: Callable self._estimate_mode_type: str self._mod_tracker = ModTracker() - self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict( + self.mod_runtimes: dict[str, dict[str, float]] = defaultdict( lambda: defaultdict(lambda: 0.0) ) - self.mod_fw_pre_order: List[str] = [] - self.mod_bw_pre_order: List[str] = [] - self.mod_fw_post_order: List[str] = [] - self.mod_bw_post_order: List[str] = [] + self.mod_fw_pre_order: list[str] = [] + self.mod_bw_pre_order: list[str] = [] + self.mod_fw_post_order: list[str] = [] + self.mod_bw_post_order: list[str] = [] self.total_runtime: float = 0.0 # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 diff --git a/torch/distributed/_tools/sac_estimator.py b/torch/distributed/_tools/sac_estimator.py index 310f3807f9c..aae73d67d89 100644 --- a/torch/distributed/_tools/sac_estimator.py +++ b/torch/distributed/_tools/sac_estimator.py @@ -4,7 +4,7 @@ import sys import warnings from collections import OrderedDict from dataclasses import astuple, dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Set +from typing import Any, NamedTuple, Optional from typing_extensions import Self import torch @@ -42,7 +42,7 @@ _PYTORCH_MIN_ALLOCATE = ( ) -def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]: +def _get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]: """ Retrieves untyped storages from a `torch.Tensor` or one of its traceable wrapper-subclass. @@ -74,7 +74,7 @@ def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]: return flattened_tensor_storages -def _display_stats_tabular(headers: List[str], table_data: List[List[Any]]) -> None: +def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None: try: from tabulate import tabulate except ImportError as err: @@ -125,7 +125,7 @@ class _SACModMetadata: start_idx: int force_store_random: bool - sac_metadata: List[_SACMetadata] + sac_metadata: list[_SACMetadata] @dataclass @@ -144,13 +144,13 @@ class SACStats: force_store_random (bool): Whether to force store random operator results. """ - func_names: List[str] - runtimes: List[float] - memory: List[int] - view_like_ops: List[int] - rand_ops: List[int] - saved_autograd_ops: List[int] - inplace_ops: List[tuple[int, int]] + func_names: list[str] + runtimes: list[float] + memory: list[int] + view_like_ops: list[int] + rand_ops: list[int] + saved_autograd_ops: list[int] + inplace_ops: list[tuple[int, int]] force_store_random: bool @@ -166,7 +166,7 @@ class MSPS(NamedTuple): msps (float): Memory per second calculated as memory/runtime. """ - func_names: Set[str] + func_names: set[str] op_idx: int memory: int runtime: float @@ -189,9 +189,9 @@ class SACTradeOffStats: """ n_segments: int - slopes: List[float] - intercepts: List[float] - fit_breaks: List[float] + slopes: list[float] + intercepts: list[float] + fit_breaks: list[float] tradeoff_curve: OrderedDict[float, float] sac_memory: int sac_runtime: float @@ -210,11 +210,11 @@ class SACGreedyOrderMeta: msps_meta (List[MSPS]): List of Memory and Runtime Statistics for operators. """ - recomputed_ops: Set[int] - stored_ops: Set[int] - inplace_op_groups: Dict[int, Set[int]] - random_ops_group: Dict[int, Set[int]] - msps_meta: List[MSPS] + recomputed_ops: set[int] + stored_ops: set[int] + inplace_op_groups: dict[int, set[int]] + random_ops_group: dict[int, set[int]] + msps_meta: list[MSPS] class SACEstimator(TorchDispatchMode): @@ -251,17 +251,17 @@ class SACEstimator(TorchDispatchMode): """ def __init__(self) -> None: - self.sac_mod_stats: Dict[str, SACStats] = {} - self.sac_mod_tradeoff_stats: Dict[str, SACTradeOffStats] = {} - self.sac_mod_greedy_order_meta: Dict[str, SACGreedyOrderMeta] = {} + self.sac_mod_stats: dict[str, SACStats] = {} + self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {} + self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {} self._mod_tracker = ModTracker() - self._sac_metadata: List[_SACMetadata] = [] - self._sac_mod_metadata: Dict[str, _SACModMetadata] = {} - self._leaf_modules: Set[str] = set() + self._sac_metadata: list[_SACMetadata] = [] + self._sac_mod_metadata: dict[str, _SACModMetadata] = {} + self._leaf_modules: set[str] = set() self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks( self._pack_hook, lambda x: x ) - self._saved_tensor_ids: Set[int] = set() + self._saved_tensor_ids: set[int] = set() self._estimate_runtime = RuntimeEstimator._roofline_estimate def _pack_hook(self, x: torch.Tensor) -> torch.Tensor: @@ -313,7 +313,7 @@ class SACEstimator(TorchDispatchMode): return all(not isinstance(x, torch.Tensor) for x in flat_inputs) def _get_sac_stats( - self, data: List[_SACMetadata], force_store_random: bool + self, data: list[_SACMetadata], force_store_random: bool ) -> SACStats: # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd # inserts those during backward and it breaks the fwd-bwd alignment @@ -381,12 +381,12 @@ class SACEstimator(TorchDispatchMode): ) def _get_inplace_metadata( - self, func: Any, out_storages: Set[UntypedStorage] - ) -> tuple[int, tuple[int, ...], Dict[str, tuple[int, ...]]]: + self, func: Any, out_storages: set[UntypedStorage] + ) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]: # 1. Get the current index of the metadata obtained so far curr_idx = len(self._sac_metadata) # 2. Get the set of active modules that are not leaf - active_mod_fqns: Set[str] = { + active_mod_fqns: set[str] = { par for par in self._mod_tracker.parents if par not in self._leaf_modules } # 3. Output ids are the identifies of the storage objects corresponding to the tensors @@ -397,7 +397,7 @@ class SACEstimator(TorchDispatchMode): op_idx = curr_idx # 5. Initialize the parent op ids of the inplace op for each of the active modules - mod_op_parent_idxs: Dict[str, int] = { + mod_op_parent_idxs: dict[str, int] = { mod_fqn: -1 for mod_fqn in active_mod_fqns } for i, d in enumerate(self._sac_metadata): @@ -430,9 +430,9 @@ class SACEstimator(TorchDispatchMode): # 1. Get the runtime estimate out, op_time = self._estimate_runtime(func, args, kwargs) flat_outs, _ = tree_flatten(out) - out_storages_cuda: Set[UntypedStorage] = set() - out_storages_cpu: Set[UntypedStorage] = set() - cuda_devices: Set[torch.device] = set() + out_storages_cuda: set[UntypedStorage] = set() + out_storages_cpu: set[UntypedStorage] = set() + cuda_devices: set[torch.device] = set() for o in flat_outs: if isinstance(o, torch.Tensor): if o.device.type == "cuda": @@ -496,8 +496,8 @@ class SACEstimator(TorchDispatchMode): # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group # The top-most op can itself be an inplace-op or can be a non-inplace op. # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads. - inplace_op_groups: Dict[int, Set[int]] = {} - inplace_op_to_group_head: Dict[int, int] = dict(sac_stats.inplace_ops) + inplace_op_groups: dict[int, set[int]] = {} + inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops) # Initialize inplace_op_groups using inplace_op_to_group_head for op_idx, group_head_idx in inplace_op_to_group_head.items(): @@ -508,7 +508,7 @@ class SACEstimator(TorchDispatchMode): # as a group. This is because, they affect the ranom seed generator. If force_store_random is set True, # all of the random ops will be stored by default. For easy of manageability, we store the top-most random op # as the leader of the random_ops_group. - random_ops_group: Dict[int, Set[int]] = {} + random_ops_group: dict[int, set[int]] = {} random_group_head_idx = min(sac_stats.rand_ops, default=-1) has_rand_ops = bool(sac_stats.rand_ops) if has_rand_ops: @@ -521,8 +521,8 @@ class SACEstimator(TorchDispatchMode): # b) If any op in the group is random and force_store_random is set, then entire group will be stored. # c) If none of ops in the group are random and the head of the group is not an in-place op, then # this group can be considered for recomputation in its entireity - stored_ops: Set[int] = set() - recomputed_ops: Set[int] = set() + stored_ops: set[int] = set() + recomputed_ops: set[int] = set() # Case 1: if has_rand_ops and sac_stats.force_store_random: stored_ops.add(random_group_head_idx) @@ -541,7 +541,7 @@ class SACEstimator(TorchDispatchMode): stored_ops.add(group_head_idx) # The potential recompute candidates are populated as: - recompute_candidates: Set[int] = set() + recompute_candidates: set[int] = set() # 1) The random group head if it is not stored if has_rand_ops and random_group_head_idx not in stored_ops: recompute_candidates.add(random_group_head_idx) @@ -557,7 +557,7 @@ class SACEstimator(TorchDispatchMode): ) # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second - msps_meta: List[MSPS] = [] + msps_meta: list[MSPS] = [] for cand_idx in recompute_candidates: op_indices = {cand_idx} if cand_idx in inplace_op_groups: @@ -598,7 +598,7 @@ class SACEstimator(TorchDispatchMode): greedy_order_meta.msps_meta, ) # 1. Intitialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops - recomp_indices: Set[int] = set() + recomp_indices: set[int] = set() for r_idx in recomputed_ops: recomp_indices.add(r_idx) if r_idx in inplace_op_groups: @@ -628,7 +628,7 @@ class SACEstimator(TorchDispatchMode): recomp_runtime / sac_runtime ) # 6. Finally, we add the memory and recomputation time of the always stored ops. - stored_indices: Set[int] = set() + stored_indices: set[int] = set() for s_idx in stored_ops: stored_indices.add(s_idx) if s_idx in inplace_op_groups: @@ -653,7 +653,7 @@ class SACEstimator(TorchDispatchMode): # save prediction graph def save_prediction_graph( - pwlf_: pwlf.PiecewiseLinFit, x: List[float], y: List[float], filename: str + pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str ) -> None: try: import matplotlib.pyplot as plt # type: ignore[import-not-found] @@ -811,8 +811,8 @@ class SACEstimator(TorchDispatchMode): recomp_runtime: float = 0.0 def append_row( - op_indices: Set[int], - func_names: Set[str], + op_indices: set[int], + func_names: set[str], msps: Optional[float] = None, stored: Optional[bool] = False, recomputed: Optional[bool] = False, @@ -839,7 +839,7 @@ class SACEstimator(TorchDispatchMode): ) for op_idx in recomputed_ops: - op_indices: Set[int] = {op_idx} + op_indices: set[int] = {op_idx} if op_idx in inplace_op_groups: op_indices.update(inplace_op_groups[op_idx]) if op_idx in random_ops_group: diff --git a/torch/distributed/_tools/sac_ilp.py b/torch/distributed/_tools/sac_ilp.py index 8d8f3915473..63ff59184e3 100644 --- a/torch/distributed/_tools/sac_ilp.py +++ b/torch/distributed/_tools/sac_ilp.py @@ -1,7 +1,7 @@ import logging import math from enum import IntEnum -from typing import Dict, List, Optional +from typing import Optional from torch.distributed._tools.ilp_utils import Graph, is_submodule from torch.distributed._tools.sac_estimator import SACStats @@ -36,9 +36,9 @@ def sac_milp( graph: Graph, memory_budget: float, world_size: int = 1, - ac_units: Optional[List[str]] = None, - fsdp_units: Optional[List[str]] = None, -) -> tuple[Dict[str, float], float, int]: + ac_units: Optional[list[str]] = None, + fsdp_units: Optional[list[str]] = None, +) -> tuple[dict[str, float], float, int]: """ MILP to decide which modules to AC and how much memory to discard. The objective is to minimize recomputation time. @@ -224,7 +224,7 @@ class SACDecision(IntEnum): def get_optimal_checkpointing_policy_per_module( sac_stats: SACStats, memory_budget: float -) -> List[int]: +) -> list[int]: """ This is adapted from -- https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375 diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index dd3861c2651..98e213792b7 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod +from collections.abc import Iterator from enum import auto, Enum from functools import partial -from typing import Any, Callable, Dict, Iterator, Optional +from typing import Any, Callable, Optional import torch import torch.nn as nn @@ -69,10 +70,10 @@ class ActivationWrapper(torch.nn.Module, ABC): @staticmethod def _post_state_dict_hook( module: nn.Module, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, *args: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ _post_state_dict_hook() is called after the state_dict() of this FSDP module is executed. @@ -87,7 +88,7 @@ class ActivationWrapper(torch.nn.Module, ABC): @staticmethod def _pre_load_state_dict_hook( module: nn.Module, - state_dict: Dict[str, Any], + state_dict: dict[str, Any], prefix: str, *args: Any, ) -> None: diff --git a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py index ada39ca24d9..569a42ffe76 100644 --- a/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py +++ b/torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import inspect from abc import ABC, abstractmethod -from typing import Dict, Type from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( @@ -15,7 +14,7 @@ from torch.optim import Optimizer # Contains the mappings between the regular and overlapped optimizer types. -_registered_overlapped_optims: Dict[Type, Type] = {} +_registered_overlapped_optims: dict[type, type] = {} def register_overlapped(optim_cls): @@ -33,7 +32,7 @@ def register_overlapped(optim_cls): class OverlappedOptimizer(ABC): - def __init__(self, optim_cls: Type) -> None: + def __init__(self, optim_cls: type) -> None: """ Initialize the OverlappedOptimizer. @@ -61,7 +60,7 @@ class OverlappedOptimizer(ABC): class _OverlappedStandardOptimizer(OverlappedOptimizer): """Overlaps a regular ``Optimizer``.""" - def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None: + def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None: super().__init__(optim_cls) f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs) self._opt_hook_state = _OptimizerHookState(f_optim, params) @@ -82,7 +81,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer): ) -def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs): +def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs): """Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``.""" for clz in inspect.getmro(optim_cls): try: diff --git a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py index 2f100061833..b7e60f2022b 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import weakref -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Optional import torch import torch.distributed as dist @@ -46,7 +46,7 @@ def _perform_local_step( # expects `None` in a list position to indicate that the corresponding # parameter should not be updated num_local_optim_params = len(zero.optim.param_groups[0]["params"]) - gradients: List[Optional[torch.Tensor]] = [ + gradients: list[Optional[torch.Tensor]] = [ _NO_PARAM_UPDATE for _ in range(num_local_optim_params) ] assert ( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py index 9d155337d24..ae8136a1359 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py @@ -1,14 +1,14 @@ # mypy: allow-untyped-defs from dataclasses import dataclass from functools import partial -from typing import Any, Callable, List, no_type_check +from typing import Any, Callable, no_type_check import torch import torch.distributed as dist from torch.autograd import Variable -__all__: List[str] = [] +__all__: list[str] = [] _FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param" diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index ed19c7d7465..00b84d6c28e 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -2,7 +2,6 @@ import logging import math from collections import defaultdict -from typing import Dict import torch import torch.distributed as dist @@ -252,9 +251,9 @@ class PowerSGDState: self.rng = np.random.RandomState(random_seed) # Since there is only a single state instance for all the input buckets, # need to maintain a dictionary that maps each bucket index to the local error. - self.error_dict: Dict[int, torch.Tensor] = {} - self.p_memory_dict: Dict[int, torch.Tensor] = {} - self.q_memory_dict: Dict[int, torch.Tensor] = {} + self.error_dict: dict[int, torch.Tensor] = {} + self.p_memory_dict: dict[int, torch.Tensor] = {} + self.q_memory_dict: dict[int, torch.Tensor] = {} # Iteration/step in the training loop. self.iter = 0 # Compression stats accumulators diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 90c679e2c2f..2bdbb5fff42 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -2,7 +2,7 @@ import warnings from abc import ABC, abstractmethod from types import TracebackType -from typing import Any, List, NamedTuple, Optional, Type +from typing import Any, NamedTuple, Optional import torch import torch.distributed as dist @@ -165,7 +165,7 @@ class Join: def __init__( self, - joinables: List[Joinable], + joinables: list[Joinable], enable: bool = True, throw_on_early_termination: bool = False, **kwargs, @@ -228,7 +228,7 @@ class Join: def __exit__( self, - type: Optional[Type[BaseException]], + type: Optional[type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType], ): diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py index b1fedb32500..eec08464167 100644 --- a/torch/distributed/algorithms/model_averaging/averagers.py +++ b/torch/distributed/algorithms/model_averaging/averagers.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs import warnings from abc import ABC, abstractmethod -from typing import Dict, Iterable, Optional, Union +from collections.abc import Iterable +from typing import Optional, Union import torch import torch.distributed as dist @@ -107,7 +108,7 @@ class PeriodicModelAverager(ModelAverager): def average_parameters( self, params: Union[ - Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] ], ): """ diff --git a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py index a27f3b762a9..a52fc2babed 100644 --- a/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py +++ b/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py @@ -3,7 +3,8 @@ import logging import warnings from collections import OrderedDict -from typing import Dict, Iterable, Union +from collections.abc import Iterable +from typing import Union import torch import torch.distributed as dist @@ -159,7 +160,7 @@ class HierarchicalModelAverager(averagers.ModelAverager): def average_parameters( self, params: Union[ - Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] ], ): """ diff --git a/torch/distributed/algorithms/model_averaging/utils.py b/torch/distributed/algorithms/model_averaging/utils.py index c03a62f0620..0438043a6e7 100644 --- a/torch/distributed/algorithms/model_averaging/utils.py +++ b/torch/distributed/algorithms/model_averaging/utils.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs # flake8: noqa C101 import itertools -from typing import Dict, Iterable, Iterator, Union +from collections.abc import Iterable, Iterator +from typing import Union import torch import torch.distributed as dist @@ -51,7 +52,7 @@ def average_parameters( def get_params_to_average( - params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]] + params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]] ): """ Return a list of parameters that need to average. @@ -81,7 +82,7 @@ def get_params_to_average( def average_parameters_or_parameter_groups( params: Union[ - Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]] + Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]] ], process_group: ProcessGroup, ): diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index c1361814cd3..c4dfb2b99e8 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -9,7 +9,7 @@ import functools import logging -from typing import Any, Callable, Dict, List, TypeVar +from typing import Any, Callable, TypeVar from typing_extensions import ParamSpec import torch @@ -18,7 +18,7 @@ from torch.distributed.logging_handlers import _log_handlers from torch.monitor import _WaitCounter -__all__: List[str] = [] +__all__: list[str] = [] _DEFAULT_DESTINATION = "default" @@ -48,7 +48,7 @@ global _c10d_logger _c10d_logger = _get_or_create_logger() -def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]: +def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]: if dist.is_initialized(): group = kwargs.get("group") or kwargs.get("process_group") msg_dict = { diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index 5e29807b3ad..b77e1ba8956 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -10,7 +10,7 @@ Each should also handle single rank scenario. from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, cast, Generic, List, Optional, TypeVar, Union +from typing import Any, Callable, cast, Generic, Optional, TypeVar, Union import torch.distributed as dist @@ -106,7 +106,7 @@ def all_gather( data_or_fn: Union[T, Callable[[], T]], stage_name: Optional[str] = None, pg: Optional[dist.ProcessGroup] = None, -) -> List[T]: +) -> list[T]: """ A simple all_gather primitive with basic synchronization guard logic, by checking payload from all ranks has the same stage name. @@ -149,11 +149,11 @@ def all_gather( all_gather_object_enforce_type(pg, total_list, sync_obj) # Each rank will throw RuntimeError in case of failure on any rank. stage_name = cast(SyncPayload[T], total_list[0]).stage_name - exception_list: List[tuple[int, Exception]] = [] - ret_list: List[T] = [] + exception_list: list[tuple[int, Exception]] = [] + ret_list: list[T] = [] error_msg: str = "" - for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)): + for i, sp in enumerate(cast(list[SyncPayload[T]], total_list)): if sp.stage_name != stage_name: error_msg += ( f"Unexpected stage name received from rank {i}: {sp.stage_name} " @@ -183,7 +183,7 @@ def all_gather( def all_gather_object_enforce_type( pg: dist.ProcessGroup, # pyre-fixme[2]: Parameter must have a type that does not contain `Any` - object_list: List[Any], + object_list: list[Any], # pyre-fixme[2]: Parameter must have a type other than `Any` obj: Any, # pyre-fixme[2]: Parameter must have a type that does not contain `Any` diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 2282affea1c..d2a79a4d1c5 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -5,7 +5,7 @@ import math import threading from functools import reduce from itertools import chain -from typing import Dict, List, Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union import torch from torch.distributed import is_available @@ -65,15 +65,15 @@ else: class _MeshEnv(threading.local): def __init__(self) -> None: - self.mesh_stack: List[DeviceMesh] = [] - self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {} - self.mesh_dim_group_options: Dict[ + self.mesh_stack: list[DeviceMesh] = [] + self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {} + self.mesh_dim_group_options: dict[ int, tuple[str, Optional[C10dBackend.Options]] ] = {} - self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {} + self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {} # Record flatten mesh name to its mesh dim index in root mesh. - self.flatten_name_to_root_dims: Dict[ - DeviceMesh, Dict[str, tuple[int, ...]] + self.flatten_name_to_root_dims: dict[ + DeviceMesh, dict[str, tuple[int, ...]] ] = {} def get_current_mesh(self) -> "DeviceMesh": @@ -85,7 +85,7 @@ else: self, device_mesh: "DeviceMesh", submesh_dim_names: tuple[str, ...], - submesh_dims: List[tuple[int, ...]], + submesh_dims: list[tuple[int, ...]], ) -> "DeviceMesh": # Get the submesh dim size from the submesh_dims. # For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want @@ -286,7 +286,7 @@ else: def _get_slice_mesh_dims( self, device_mesh, mesh_dim_names - ) -> List[tuple[int, ...]]: + ) -> list[tuple[int, ...]]: """ Validate whether the mesh_dim_names is valid for slicing the given device_mesh. If valid, return dim indexes of the slice mesh in the device mesh. @@ -338,7 +338,7 @@ else: def _get_all_submeshes( self, device_mesh: "DeviceMesh", mesh_dim_name: str - ) -> List["DeviceMesh"]: + ) -> list["DeviceMesh"]: """ Return all the submeshes of a given mesh dimension of the device mesh. """ @@ -458,7 +458,7 @@ else: # calculate the coordinates of the current global rank on the mesh rank_coords = (self.mesh == get_rank()).nonzero() assert rank_coords.size(0) in (0, 1) - self._coordinate_on_dim: Optional[List[int]] = ( + self._coordinate_on_dim: Optional[list[int]] = ( rank_coords[0].tolist() if rank_coords.size(0) > 0 else None ) @@ -498,7 +498,7 @@ else: # TODO(yifu): remove tag and ranks once we fully migrate to native # functional collectives. See details in: # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 - dim_group_infos: List[tuple[str, List[int], str]] = [] + dim_group_infos: list[tuple[str, list[int], str]] = [] default_group = _get_default_group() if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): @@ -775,7 +775,7 @@ else: _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index] ) - def get_all_groups(self) -> List[ProcessGroup]: + def get_all_groups(self) -> list[ProcessGroup]: """ Returns a list of ProcessGroups for all mesh dimensions. @@ -786,7 +786,7 @@ else: @staticmethod def from_group( - group: Union[ProcessGroup, List[ProcessGroup]], + group: Union[ProcessGroup, list[ProcessGroup]], device_type: str, mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, *, @@ -910,7 +910,7 @@ else: ), "We expect ProcessGroup before calling `get_rank`!" return not_none(get_rank(mesh_dim_group)) - def get_coordinate(self) -> Optional[List[int]]: + def get_coordinate(self) -> Optional[list[int]]: """ Return the relative indices of this rank relative to all dimensions of the mesh. If this rank is not part of the mesh, return None. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 7e834821c73..1b0442c1542 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -15,7 +15,7 @@ import time import warnings from collections import namedtuple from datetime import timedelta -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import deprecated import torch @@ -259,18 +259,18 @@ class Backend(str): _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) - _plugins: Dict[str, _BackendPlugin] = {} + _plugins: dict[str, _BackendPlugin] = {} backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] # 3rd-party devices can register the default backend support here - default_device_backend_map: Dict[str, str] = { + default_device_backend_map: dict[str, str] = { "cpu": GLOO, "cuda": NCCL, "xpu": XCCL, } - backend_capability: Dict[str, List[str]] = { + backend_capability: dict[str, list[str]] = { GLOO: ["cpu", "cuda"], NCCL: ["cuda"], XCCL: ["xpu"], @@ -278,7 +278,7 @@ class Backend(str): MPI: ["cpu", "cuda"], } - backend_type_map: Dict[str, ProcessGroup.BackendType] = { + backend_type_map: dict[str, ProcessGroup.BackendType] = { UNDEFINED: ProcessGroup.BackendType.UNDEFINED, GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, @@ -303,7 +303,7 @@ class Backend(str): name, func, extended_api=False, - devices: Optional[Union[str, List[str]]] = None, + devices: Optional[Union[str, list[str]]] = None, ) -> None: """ Register a new backend with the given name and instantiating function. @@ -371,7 +371,7 @@ class BackendConfig: def __init__(self, backend: Backend): """Init.""" - self.device_backend_map: Dict[str, Backend] = {} + self.device_backend_map: dict[str, Backend] = {} backend = str(backend) if backend == Backend.UNDEFINED: @@ -441,7 +441,7 @@ class BackendConfig: f"{device}:{backend}" for device, backend in self.device_backend_map.items() ) - def get_device_backend_map(self) -> Dict[str, Backend]: + def get_device_backend_map(self) -> dict[str, Backend]: """Return backend map of the device.""" return self.device_backend_map @@ -572,14 +572,14 @@ class _CollOp: # DO NOT USE THESE FIELDS DIRECTLY. # Use them through the _world object to make sure the _world override mechanism -_pg_map: Dict[ProcessGroup, tuple[str, Store]] = {} -_pg_names: Dict[ProcessGroup, str] = {} -_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {} +_pg_map: dict[ProcessGroup, tuple[str, Store]] = {} +_pg_names: dict[ProcessGroup, str] = {} +_pg_group_ranks: dict[ProcessGroup, dict[int, int]] = {} # For a pg, it is a map from ProcessGroup to BackendConfig -_pg_backend_config: Dict[ProcessGroup, str] = {} +_pg_backend_config: dict[ProcessGroup, str] = {} _group_count = 0 -_tags_to_pg: Dict[str, List[ProcessGroup]] = {} -_pg_to_tag: Dict[ProcessGroup, str] = {} +_tags_to_pg: dict[str, list[ProcessGroup]] = {} +_pg_to_tag: dict[ProcessGroup, str] = {} _backend: Optional[str] = None @@ -595,7 +595,7 @@ class _World: def __init__(self) -> None: self._default_pg = None - self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {} + self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {} @property def default_pg(self) -> Optional[ProcessGroup]: @@ -612,7 +612,7 @@ class _World: self._default_pg = value @property - def pg_map(self) -> Dict[ProcessGroup, tuple[str, Store]]: + def pg_map(self) -> dict[ProcessGroup, tuple[str, Store]]: """ Provide Mapping from ProcessGroup to backend name and store. @@ -625,7 +625,7 @@ class _World: return _pg_map @property - def pg_names(self) -> Dict[ProcessGroup, str]: + def pg_names(self) -> dict[ProcessGroup, str]: """ Process group's names, map from ProcessGroup to str. @@ -635,7 +635,7 @@ class _World: return _pg_names @property - def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]: + def pg_group_ranks(self) -> dict[ProcessGroup, dict[int, int]]: """ Process group's global rank to local rank mapping. @@ -645,7 +645,7 @@ class _World: return _pg_group_ranks @property - def pg_backend_config(self) -> Dict[ProcessGroup, str]: + def pg_backend_config(self) -> dict[ProcessGroup, str]: """ Process group's backend config. @@ -671,27 +671,27 @@ class _World: _group_count = value @property - def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]: + def tags_to_pg(self) -> dict[str, list[ProcessGroup]]: global _tags_to_pg return _tags_to_pg @property - def pg_to_tag(self) -> Dict[ProcessGroup, str]: + def pg_to_tag(self) -> dict[ProcessGroup, str]: global _pg_to_tag return _pg_to_tag @property - def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]: + def pg_coalesce_state(self) -> dict[ProcessGroup, list[_CollOp]]: return self._pg_coalesce_state @property - def pg_config_info(self) -> List[Dict[str, Any]]: + def pg_config_info(self) -> list[dict[str, Any]]: """ Return a list of dict with process groups and backends. Along with their unique IDs and configurations (types and ranks). """ - config_info: List[Dict[str, Any]] = [] + config_info: list[dict[str, Any]] = [] default_pg_size = _get_group_size(None) for pg in self.pg_map.keys(): ranks = self.pg_group_ranks[pg] @@ -912,7 +912,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device return rv -def _device_capability(group: Optional[ProcessGroup] = None) -> List[str]: +def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]: """ Return the device type(s) supported by ``group``. @@ -1077,7 +1077,7 @@ def _get_global_rank(group, rank) -> int: return get_global_rank(group, rank) -def get_process_group_ranks(group: ProcessGroup) -> List[int]: +def get_process_group_ranks(group: ProcessGroup) -> list[int]: """ Get all ranks associated with ``group``. @@ -1103,7 +1103,7 @@ def _get_group_size_by_name(group_name: str) -> int: return group.size() -def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str: +def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> str: # TODO(yifu): remove this function once ranks + tag is not a supported # identifier for process group for functional collectives. group = _find_pg_by_ranks_and_tag(tag, ranks) @@ -1401,7 +1401,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int: return -1 -def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: +def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]: """ Return the pg configuration of the given process group. @@ -1416,12 +1416,12 @@ def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]: } -def _get_all_pg_configs() -> List[Dict[str, Any]]: +def _get_all_pg_configs() -> list[dict[str, Any]]: """ Return the pg configuration of all the process groups. """ - config_info: List[Dict[str, Any]] = [ + config_info: list[dict[str, Any]] = [ _get_pg_config(pg) for pg in _world.pg_map.keys() ] return config_info @@ -2507,7 +2507,7 @@ class _IllegalWork(Work): class _CoalescingManager: def __init__(self) -> None: - self.works: List[Work] = [] + self.works: list[Work] = [] def append(self, work: Work): if work: @@ -2607,7 +2607,7 @@ def _coalescing_manager( work.wait() # type: ignore[possibly-undefined] -def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]: +def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]: """ Send or Receive a batch of tensors asynchronously and return a list of requests. @@ -2651,7 +2651,7 @@ def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]: group = p2p_op_list[0].group device = p2p_op_list[0].tensor.device - def peer_kwarg(op: P2POp) -> Dict[str, int]: + def peer_kwarg(op: P2POp) -> dict[str, int]: key = "group_dst" if op.op == isend else "group_src" return {key: op.group_peer} @@ -3054,7 +3054,7 @@ def all_gather_object(object_list, obj, group=None): @_exception_logger def gather_object( obj: Any, - object_gather_list: Optional[List[Any]] = None, + object_gather_list: Optional[list[Any]] = None, dst: Optional[int] = None, group: Optional[ProcessGroup] = None, group_dst: Optional[int] = None, @@ -3178,7 +3178,7 @@ def gather_object( @_exception_logger def send_object_list( - object_list: List[Any], + object_list: list[Any], dst: Optional[int] = None, group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, @@ -3277,7 +3277,7 @@ def send_object_list( @_exception_logger def recv_object_list( - object_list: List[Any], + object_list: list[Any], src: Optional[int] = None, group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, @@ -3379,7 +3379,7 @@ def recv_object_list( @_exception_logger def broadcast_object_list( - object_list: List[Any], + object_list: list[Any], src: Optional[int] = None, group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, @@ -3505,8 +3505,8 @@ def broadcast_object_list( @_exception_logger def scatter_object_list( - scatter_object_output_list: List[Any], - scatter_object_input_list: Optional[List[Any]] = None, + scatter_object_output_list: list[Any], + scatter_object_input_list: Optional[list[Any]] = None, src: Optional[int] = None, group: Optional[ProcessGroup] = None, group_src: Optional[int] = None, @@ -3933,7 +3933,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): @_exception_logger def gather( tensor: torch.Tensor, - gather_list: Optional[List[torch.Tensor]] = None, + gather_list: Optional[list[torch.Tensor]] = None, dst: Optional[int] = None, group: Optional[ProcessGroup] = None, async_op: bool = False, @@ -4013,7 +4013,7 @@ def gather( @_exception_logger def scatter( tensor: torch.Tensor, - scatter_list: Optional[List[torch.Tensor]] = None, + scatter_list: Optional[list[torch.Tensor]] = None, src: Optional[int] = None, group: Optional[ProcessGroup] = None, async_op: bool = False, @@ -4657,7 +4657,7 @@ def _create_process_group_wrapper( # helper function for deterministically hashing a list of ranks to a unique # string -def _hash_ranks_to_str(ranks: List[int]) -> str: +def _hash_ranks_to_str(ranks: list[int]) -> str: rank_join: str = "_".join(map(str, ranks)) # In case there is already a PG with the same rank composition unique_str = "_".join([rank_join, str(len(_world.pg_names))]) @@ -4665,7 +4665,7 @@ def _hash_ranks_to_str(ranks: List[int]) -> str: # Takes a list of ranks and computes an integer color -def _process_group_color(ranks: List[int]) -> int: +def _process_group_color(ranks: list[int]) -> int: # Convert list to tuple to make it hashable ranks = tuple(ranks) hash_value = hash(ranks) @@ -5315,7 +5315,7 @@ def new_subgroups_by_enumeration( return cur_subgroup, subgroups -def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]: +def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]: if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"): tag = f"user:{tag}" @@ -5331,7 +5331,7 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGro def _find_or_create_pg_by_ranks_and_tag( - tag: str, ranks: List[int], stride: int + tag: str, ranks: list[int], stride: int ) -> ProcessGroup: assert ( len(ranks) % stride == 0 diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 05367ebeaff..d8e2017e7e1 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -9,7 +9,7 @@ import sys import uuid from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Optional, Union import torch.distributed.elastic.rendezvous.registry as rdzv_registry from torch.distributed.elastic import events, metrics @@ -79,13 +79,13 @@ class LaunchConfig: role: str = "default_role" rdzv_endpoint: str = "" rdzv_backend: str = "etcd" - rdzv_configs: Dict[str, Any] = field(default_factory=dict) + rdzv_configs: dict[str, Any] = field(default_factory=dict) rdzv_timeout: int = -1 max_restarts: int = 3 monitor_interval: float = 0.1 start_method: str = "spawn" log_line_prefix_template: Optional[str] = None - metrics_cfg: Dict[str, str] = field(default_factory=dict) + metrics_cfg: dict[str, str] = field(default_factory=dict) local_addr: Optional[str] = None def __post_init__(self): @@ -140,7 +140,7 @@ class elastic_launch: def _get_entrypoint_name( - entrypoint: Union[Callable, str, None], args: List[Any] + entrypoint: Union[Callable, str, None], args: list[Any] ) -> str: """Retrieve entrypoint name with the rule: 1. If entrypoint is a function, use ``entrypoint.__qualname__``. @@ -183,8 +183,8 @@ def _get_addr_and_port( def launch_agent( config: LaunchConfig, entrypoint: Union[Callable, str, None], - args: List[Any], -) -> Dict[int, Any]: + args: list[Any], +) -> dict[int, Any]: if not config.run_id: run_id = str(uuid.uuid4().int) logger.warning("config has no run_id, generated a random run_id: %s", run_id) diff --git a/torch/distributed/logging_handlers.py b/torch/distributed/logging_handlers.py index 021ad100f06..ed6832fd1ae 100644 --- a/torch/distributed/logging_handlers.py +++ b/torch/distributed/logging_handlers.py @@ -7,11 +7,10 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Dict, List -__all__: List[str] = [] +__all__: list[str] = [] -_log_handlers: Dict[str, logging.Handler] = { +_log_handlers: dict[str, logging.Handler] = { "default": logging.NullHandler(), } diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 118c5a253a4..40fec71787a 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -4,20 +4,8 @@ import collections import io import sys import types -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Mapping, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) +from collections.abc import Iterator, Mapping +from typing import Any, Callable, Optional, TypeVar, Union import torch import torch.distributed.rpc as rpc @@ -109,8 +97,8 @@ def _create_module_with_interface( return rpc.RRef(module, module_interface_cls) -def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]: - ret: List[rpc.RRef[Parameter]] = [ +def _param_rrefs(module_rref, recurse) -> list[rpc.RRef[Parameter]]: + ret: list[rpc.RRef[Parameter]] = [ rpc.RRef(param) for param in module_rref.local_value().parameters(recurse) ] return ret @@ -129,9 +117,9 @@ class _RemoteModule(nn.Module): def __init__( self, remote_device: str, - module_cls: Type[nn.Module], - args: Optional[Tuple] = None, - kwargs: Optional[Dict[str, Any]] = None, + module_cls: type[nn.Module], + args: Optional[tuple] = None, + kwargs: Optional[dict[str, Any]] = None, _module_interface_cls: Any = None, ): """ @@ -282,7 +270,7 @@ class _RemoteModule(nn.Module): self._install_generated_methods() self._check_attribute_picklability() - def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]: + def remote_parameters(self, recurse: bool = True) -> list[rpc.RRef[Parameter]]: """ Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters. @@ -371,8 +359,8 @@ class _RemoteModule(nn.Module): hook: Union[ Callable[[T, tuple[Any, ...]], Optional[Any]], Callable[ - [T, tuple[Any, ...], Dict[str, Any]], - Optional[tuple[Any, Dict[str, Any]]], + [T, tuple[Any, ...], dict[str, Any]], + Optional[tuple[Any, dict[str, Any]]], ], ], prepend: bool = False, @@ -384,7 +372,7 @@ class _RemoteModule(nn.Module): self, hook: Union[ Callable[[T, tuple[Any, ...], Any], Optional[Any]], - Callable[[T, tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], + Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], ], prepend: bool = False, with_kwargs: bool = False, @@ -431,7 +419,7 @@ class _RemoteModule(nn.Module): def named_modules( self, - memo: Optional[Set[Module]] = None, + memo: Optional[set[Module]] = None, prefix: str = "", remove_duplicate: bool = True, ): @@ -681,9 +669,9 @@ class RemoteModule(_RemoteModule): def __init__( self, remote_device: str, - module_cls: Type[nn.Module], - args: Optional[Tuple] = None, - kwargs: Optional[Dict[str, Any]] = None, + module_cls: type[nn.Module], + args: Optional[tuple] = None, + kwargs: Optional[dict[str, Any]] = None, ): super().__init__(remote_device, module_cls, args, kwargs) diff --git a/torch/distributed/optim/apply_optimizer_in_backward.py b/torch/distributed/optim/apply_optimizer_in_backward.py index 36f679f4eba..741fe350121 100644 --- a/torch/distributed/optim/apply_optimizer_in_backward.py +++ b/torch/distributed/optim/apply_optimizer_in_backward.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, Iterable, List, no_type_check, Type +from collections.abc import Iterable +from typing import Any, no_type_check import torch -__all__: List[str] = [] +__all__: list[str] = [] # WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter # without changing it's life-time. @@ -15,9 +16,9 @@ param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() @no_type_check def _apply_optimizer_in_backward( - optimizer_class: Type[torch.optim.Optimizer], + optimizer_class: type[torch.optim.Optimizer], params: Iterable[torch.nn.Parameter], - optimizer_kwargs: Dict[str, Any], + optimizer_kwargs: dict[str, Any], register_hook: bool = True, ) -> None: """ @@ -97,7 +98,7 @@ def _apply_optimizer_in_backward( _apply_optimizer_in_backward_to_param(param) -def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]: +def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Optimizer]: """ Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called @@ -113,7 +114,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Opt _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01}) optims = _get_optimizers_in_backward(model) """ - optims: List[torch.optim.Optimizer] = [] + optims: list[torch.optim.Optimizer] = [] for param in module.parameters(): optims.extend(getattr(param, "_in_backward_optimizers", [])) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py index 81bff38773a..9af7bba4680 100644 --- a/torch/distributed/optim/functional_adadelta.py +++ b/torch/distributed/optim/functional_adadelta.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional Adadelta Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalAdadelta: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1.0, rho: float = 0.9, eps: float = 1e-6, @@ -51,9 +51,9 @@ class _FunctionalAdadelta: # param group as it's not a common use case. self.param_group = {"params": params} - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_adagrad.py b/torch/distributed/optim/functional_adagrad.py index 2c8ac898a45..5820a94183c 100644 --- a/torch/distributed/optim/functional_adagrad.py +++ b/torch/distributed/optim/functional_adagrad.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional Adagrad Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalAdagrad: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-2, lr_decay: float = 0.0, weight_decay: float = 0.0, @@ -53,7 +53,7 @@ class _FunctionalAdagrad: self.foreach = foreach self.fused = fused self.maximize = maximize - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) if len(params) == 0 and not _allow_empty_param_list: raise ValueError("optimizer got an empty parameter list") @@ -70,12 +70,12 @@ class _FunctionalAdagrad: "step": torch.tensor(0.0), } - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] state_sums = [] - state_steps: List[Tensor] = [] + state_steps: list[Tensor] = [] if len(params) != len(gradients): raise ValueError( diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py index 60b6a5ea3ec..b736cd4d164 100644 --- a/torch/distributed/optim/functional_adam.py +++ b/torch/distributed/optim/functional_adam.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional Adam Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalAdam: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, @@ -59,7 +59,7 @@ class _FunctionalAdam: self.maximize = maximize self.foreach = foreach self.fused = fused - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) if len(params) == 0 and not _allow_empty_param_list: raise ValueError("optimizer got an empty parameter list") @@ -78,7 +78,7 @@ class _FunctionalAdam: exp_avgs = [] exp_avg_sqs = [] max_exp_avg_sqs = [] - state_steps: List[Tensor] = [] + state_steps: list[Tensor] = [] has_complex = torch.is_complex(param) if grad is not None: params_with_grad.append(param) @@ -128,14 +128,14 @@ class _FunctionalAdam: found_inf=None, ) - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] exp_avgs = [] exp_avg_sqs = [] max_exp_avg_sqs = [] - state_steps: List[Tensor] = [] + state_steps: list[Tensor] = [] has_complex = False if len(params) != len(gradients): diff --git a/torch/distributed/optim/functional_adamax.py b/torch/distributed/optim/functional_adamax.py index a9ddaa9df6e..9327eca3abf 100644 --- a/torch/distributed/optim/functional_adamax.py +++ b/torch/distributed/optim/functional_adamax.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional Adamax Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalAdamax: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, @@ -55,7 +55,7 @@ class _FunctionalAdamax: } self.foreach = foreach self.maximize = maximize - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) if len(params) == 0 and not _allow_empty_param_list: raise ValueError("optimizer got an empty parameter list") @@ -64,13 +64,13 @@ class _FunctionalAdamax: # param group as it's not a common use case. self.param_group = {"params": params} - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] exp_avgs = [] exp_infs = [] - state_steps: List[Tensor] = [] + state_steps: list[Tensor] = [] if len(params) != len(gradients): raise ValueError( diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py index 25502f7cbf2..8d79cc0f27f 100644 --- a/torch/distributed/optim/functional_adamw.py +++ b/torch/distributed/optim/functional_adamw.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional AdamW Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalAdamW: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, @@ -59,7 +59,7 @@ class _FunctionalAdamW: self.maximize = maximize self.foreach = foreach self.fused = fused - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) if len(params) == 0 and not _allow_empty_param_list: raise ValueError("optimizer got an empty parameter list") @@ -74,7 +74,7 @@ class _FunctionalAdamW: exp_avgs = [] exp_avg_sqs = [] max_exp_avg_sqs = [] - state_steps: List[Tensor] = [] + state_steps: list[Tensor] = [] has_complex = torch.is_complex(param) if grad is not None: params_with_grad.append(param) @@ -129,14 +129,14 @@ class _FunctionalAdamW: has_complex=has_complex, ) - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] exp_avgs = [] exp_avg_sqs = [] max_exp_avg_sqs = [] - state_steps: List[Tensor] = [] + state_steps: list[Tensor] = [] if len(params) != len(gradients): raise ValueError( diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py index 1362c1e635f..424c2276bff 100644 --- a/torch/distributed/optim/functional_rmsprop.py +++ b/torch/distributed/optim/functional_rmsprop.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional RMSprop Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalRMSprop: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-2, alpha: float = 0.99, eps: float = 1e-8, @@ -55,9 +55,9 @@ class _FunctionalRMSprop: # param group as it's not a common use case. self.param_group = {"params": params} - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_rprop.py b/torch/distributed/optim/functional_rprop.py index 8337c20f7e2..877ea6bddef 100644 --- a/torch/distributed/optim/functional_rprop.py +++ b/torch/distributed/optim/functional_rprop.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional Rprop Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalRprop: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-2, etas: tuple[float, float] = (0.5, 1.2), step_sizes: tuple[float, float] = (1e-6, 50), @@ -49,9 +49,9 @@ class _FunctionalRprop: # param group as it's not a common use case. self.param_group = {"params": params} - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index 4644b6a4030..e0a00cf02e9 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional +from typing import Optional import torch import torch.optim._functional as F @@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import ( ) -__all__: List[str] = [] +__all__: list[str] = [] # Define a TorchScript compatible Functional SGD Optimizer @@ -25,7 +25,7 @@ __all__: List[str] = [] class _FunctionalSGD: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-2, momentum: float = 0.0, dampening: float = 0.0, @@ -47,7 +47,7 @@ class _FunctionalSGD: self.maximize = maximize self.foreach = foreach self.fused = fused - self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) if len(params) == 0 and not _allow_empty_param_list: raise ValueError("optimizer got an empty parameter list") @@ -67,7 +67,7 @@ class _FunctionalSGD: dampening = self.defaults["dampening"] lr = self.defaults["lr"] params = [param] - momentum_buffer_list: List[Optional[Tensor]] = [] + momentum_buffer_list: list[Optional[Tensor]] = [] grads = [] has_sparse_grad = False @@ -106,11 +106,11 @@ class _FunctionalSGD: if momentum_buffer is not None: state["momentum_buffer"] = momentum_buffer - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): params = self.param_group["params"] params_with_grad = [] grads = [] - momentum_buffer_list: List[Optional[Tensor]] = [] + momentum_buffer_list: list[Optional[Tensor]] = [] lr = self.defaults["lr"] weight_decay = self.defaults["weight_decay"] momentum = self.defaults["momentum"] diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 4130189967a..dbbd2ac9713 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -1,18 +1,9 @@ # mypy: allow-untyped-defs import logging import warnings +from collections.abc import Collection, Mapping from copy import deepcopy -from typing import ( - Any, - Callable, - Collection, - Dict, - List, - Mapping, - Optional, - overload, - Union, -) +from typing import Any, Callable, Optional, overload, Union import torch import torch.nn as nn @@ -21,7 +12,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -__all__: List[str] = [] +__all__: list[str] = [] logger = logging.getLogger(__name__) @@ -129,7 +120,7 @@ class _NamedOptimizer(optim.Optimizer): ) param_group["params"] = params - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: """ Return the ``state_dict`` of the optimizer. @@ -317,7 +308,7 @@ class _NamedOptimizer(optim.Optimizer): # Calling ``step`` will load the initial state for optimizer states. self.step(closure=None) - def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]: + def _pre_load_state_dict(self, state_dict) -> dict[str, Any]: # TODO(chienchin): This API should be FSDP agnostic and should support # general user hooks. if isinstance(self.module, FSDP): @@ -326,7 +317,7 @@ class _NamedOptimizer(optim.Optimizer): ) return state_dict - def _post_state_dict(self, state_dict) -> Dict[str, Any]: + def _post_state_dict(self, state_dict) -> dict[str, Any]: # TODO(chienchin): This API should be FSDP agnostic and should support # general user hooks. if isinstance(self.module, FSDP): @@ -334,6 +325,6 @@ class _NamedOptimizer(optim.Optimizer): return state_dict -def _gen_param_group_key(param_keys: List[str]) -> str: +def _gen_param_group_key(param_keys: list[str]) -> str: """Concatenate all param keys as a unique indentifier for one param group.""" return "/".join(sorted(param_keys)) diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index a32e5505ffe..cb7fb8a26a2 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -3,7 +3,7 @@ import logging from collections import defaultdict from threading import Lock -from typing import List, Optional +from typing import Optional import torch import torch.distributed.autograd as dist_autograd @@ -52,7 +52,7 @@ class _ScriptLocalOptimizer(nn.Module): def step(self, autograd_ctx_id: int): all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) # apply functional optimizer step with a list of gradients - grads: List[Optional[Tensor]] = [ + grads: list[Optional[Tensor]] = [ all_local_grads[p] if p in all_local_grads else None for p in self._local_params ] diff --git a/torch/distributed/optim/utils.py b/torch/distributed/optim/utils.py index d2c75eee7e3..5d1272fb693 100644 --- a/torch/distributed/optim/utils.py +++ b/torch/distributed/optim/utils.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Type from torch import optim @@ -46,7 +45,7 @@ def register_functional_optim(key, optim): functional_optim_map[key] = optim -def as_functional_optim(optim_cls: Type, *args, **kwargs): +def as_functional_optim(optim_cls: type, *args, **kwargs): try: functional_cls = functional_optim_map[optim_cls] except KeyError as e: @@ -57,7 +56,7 @@ def as_functional_optim(optim_cls: Type, *args, **kwargs): return _create_functional_optim(functional_cls, *args, **kwargs) -def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs): +def _create_functional_optim(functional_optim_cls: type, *args, **kwargs): return functional_optim_cls( [], *args, diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 72a13ef0959..b234fb72fad 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -11,7 +11,7 @@ import inspect import io import logging from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Set, Type, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -156,7 +156,7 @@ class _DDPBucketAssignment: def __init__( self, bucket_index: int, - parameters: List[torch.Tensor], + parameters: list[torch.Tensor], offset: int, ): self.bucket_index = bucket_index @@ -239,20 +239,20 @@ class _OverlapInfo: self.shard_buckets: bool = False # Modified per bucket reconstruction - self.params_per_bucket: List[List[torch.Tensor]] = [] - self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)] - self.offsets: Dict[int, int] = {} + self.params_per_bucket: list[list[torch.Tensor]] = [] + self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)] + self.offsets: dict[int, int] = {} # Group Ranks - self.assigned_ranks_per_bucket: List[Set[int]] = [] + self.assigned_ranks_per_bucket: list[set[int]] = [] self.num_bucket_assignments: int = 0 self.total_size: Optional[int] = None # Modified per iteration - self.broadcast_handles: List[Any] = [] - self.bucket_indices_seen: List[int] = [] + self.broadcast_handles: list[Any] = [] + self.bucket_indices_seen: list[int] = [] # Used by `hook_with_zero_step()` - self.bucket_index_to_future: Dict[int, torch.futures.Future] = {} - self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {} + self.bucket_index_to_future: dict[int, torch.futures.Future] = {} + self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {} def wait_for_broadcasts(self) -> None: r""" @@ -372,7 +372,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def __init__( self, params, - optimizer_class: Type[Optimizer], + optimizer_class: type[Optimizer], process_group: Optional[Any] = None, parameters_as_bucket_view: bool = False, overlap_with_ddp: bool = False, @@ -395,15 +395,15 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): # `self.param_groups` # Internal data structures (`_cache` indicates lazily evaluated) - self._param_to_rank_cache: Dict[torch.Tensor, int] = {} - self._param_to_index_cache: Dict[torch.Tensor, int] = {} - self._partition_parameters_cache: List[List[Dict]] = [] - self._index_to_param_cache: List[torch.Tensor] = [] - self._device_to_params_per_rank_cache: Dict[ - torch.device, List[List[torch.Tensor]] + self._param_to_rank_cache: dict[torch.Tensor, int] = {} + self._param_to_index_cache: dict[torch.Tensor, int] = {} + self._partition_parameters_cache: list[list[dict]] = [] + self._index_to_param_cache: list[torch.Tensor] = [] + self._device_to_params_per_rank_cache: dict[ + torch.device, list[list[torch.Tensor]] ] = {} - self._bucket_assignments_per_rank_cache: List[ - Dict[int, _DDPBucketAssignment] + self._bucket_assignments_per_rank_cache: list[ + dict[int, _DDPBucketAssignment] ] = [] self._is_trainable_mask = self._get_is_trainable_mask() @@ -439,12 +439,12 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): # `self._buckets` is used if `parameters_as_bucket_view=True`, in # which case parameter data is flattened into contiguous bucket tensors self.parameters_as_bucket_view = parameters_as_bucket_view - self._buckets: List[List[torch.Tensor]] = [] + self._buckets: list[list[torch.Tensor]] = [] self._build_param_buckets() # Optional consolidated optimizer state, only populated if this rank # is the target in `consolidate_state_dict()` - self._all_state_dicts: List[Dict[str, Any]] = [] + self._all_state_dicts: list[dict[str, Any]] = [] self.initialized = True @@ -457,7 +457,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._device_to_params_per_rank_cache.clear() self._bucket_assignments_per_rank_cache.clear() - def add_param_group(self, param_group: Dict[str, Any]) -> None: + def add_param_group(self, param_group: dict[str, Any]) -> None: r""" Add a parameter group to the :class:`Optimizer` 's ``param_groups``. @@ -586,7 +586,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _verify_params_per_rank( self, - params_per_rank: List[List[torch.Tensor]], + params_per_rank: list[list[torch.Tensor]], ) -> None: r""" Verify ``params_per_rank`` for :meth:`_partition_parameters`. @@ -619,7 +619,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): ) def _partition_param_group( - self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]] + self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]] ) -> None: r""" Partition the parameter group ``param_group`` according to ``params_per_rank``. @@ -641,8 +641,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _partition_parameters( self, - params_per_rank: Optional[List[List[torch.Tensor]]] = None, - ) -> List[List[Dict]]: + params_per_rank: Optional[list[list[torch.Tensor]]] = None, + ) -> list[list[dict]]: r""" Partitions parameters across distributed data parallel ranks. @@ -674,7 +674,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._partition_parameters_cache = [[] for _ in range(self.world_size)] sizes = [0] * self.world_size for param_group in self.param_groups: - param_group_params_per_rank: List[List] = [ + param_group_params_per_rank: list[list] = [ [] for _ in range(self.world_size) ] # Sort the parameters by size (largest first) @@ -712,7 +712,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): return self._partition_parameters_cache @property - def _param_to_rank(self) -> Dict[torch.Tensor, int]: + def _param_to_rank(self) -> dict[torch.Tensor, int]: r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition.""" if len(self._param_to_rank_cache) == 0: for rank, param_groups in enumerate(self._partition_parameters()): @@ -722,7 +722,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): return self._param_to_rank_cache @property - def _param_to_index(self) -> Dict[torch.Tensor, int]: + def _param_to_index(self) -> dict[torch.Tensor, int]: r""" :class:`dict` mapping parameters to their indices in the global optimizer state. @@ -737,7 +737,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): return self._param_to_index_cache @property - def _index_to_param(self) -> List[torch.Tensor]: + def _index_to_param(self) -> list[torch.Tensor]: r"""List mapping parameter indices in the global optimizer scheme to the actual params.""" if len(self._index_to_param_cache) == 0: self._index_to_param_cache = list( @@ -811,7 +811,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): @property def _device_to_params_per_rank( self, - ) -> Dict[torch.device, List[List[torch.Tensor]]]: + ) -> dict[torch.device, list[list[torch.Tensor]]]: r""" Return device parameters assigned per rank. @@ -854,8 +854,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _get_min_index( self, - values: List[int], - disallowed_indices: Optional[Set[int]] = None, + values: list[int], + disallowed_indices: Optional[set[int]] = None, ) -> int: r""" Return ``values.index(min(values))``, except only uses one pass. @@ -881,10 +881,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _assign_bucket_subset_to_rank( self, bucket_index: int, - bucket_params: List[torch.Tensor], + bucket_params: list[torch.Tensor], bucket_offset: int, assigned_rank: int, - assigned_ranks_per_bucket: List[Set[int]], + assigned_ranks_per_bucket: list[set[int]], ) -> None: r""" Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information. @@ -919,7 +919,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._overlap_info.num_bucket_assignments += 1 @property - def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]: + def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]: r""" Return DDP bucket parameters assigned per rank. @@ -1015,7 +1015,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _local_step( self, - gradients: Optional[List[Optional[torch.Tensor]]] = None, + gradients: Optional[list[Optional[torch.Tensor]]] = None, closure: Optional[Callable[[], float]] = None, **kwargs: Any, ) -> Optional[float]: @@ -1148,7 +1148,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): r"""Return process group.""" return self.process_group - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + def load_state_dict(self, state_dict: dict[str, Any]) -> None: r""" Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed. @@ -1186,7 +1186,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._sync_param_groups(state_dict["param_groups"], self.param_groups) self._sync_param_groups(self.param_groups, self.optim.param_groups) - def state_dict(self) -> Dict[str, Any]: + def state_dict(self) -> dict[str, Any]: r""" Return the last global optimizer state known to this rank. @@ -1252,8 +1252,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): @staticmethod def _sync_param_groups( - src_param_groups: List[Dict[Any, Any]], - dst_param_groups: List[Dict[Any, Any]], + src_param_groups: list[dict[Any, Any]], + dst_param_groups: list[dict[Any, Any]], ) -> None: r""" Sync the attributes from the source parameter groups to the destination parameter groups. @@ -1381,7 +1381,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): def _verify_and_init_params( self, params: Any, - ) -> Union[List[torch.Tensor], List[dict]]: + ) -> Union[list[torch.Tensor], list[dict]]: r""" Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. @@ -1469,7 +1469,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): f"{other_typename}" ) - def _get_is_trainable_mask(self) -> List[bool]: + def _get_is_trainable_mask(self) -> list[bool]: r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not.""" return list(map(_is_trainable, self._all_params)) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 8669f49baca..5bffdccf934 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -7,7 +7,7 @@ from collections import defaultdict from enum import Enum from inspect import Parameter, Signature, signature from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch.fx as fx @@ -116,7 +116,7 @@ def _insert_stage_symbolic_backward( output_node: fx.Node, ): # Collect metadata about tuple output values. TODO: move this to split_module or FX IR - tuples: Dict[fx.Node, Tuple] = {} + tuples: dict[fx.Node, tuple] = {} for node in reversed(g.nodes): if node.op == "call_function": # In the forward pass, only emit placeholder, module calls, and @@ -155,7 +155,7 @@ def _insert_stage_symbolic_backward( # We will only emit backward operations for nodes that can contribute # to the specified loss value. live_nodes = {loss_node: None} - val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None} + val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None} def assign_or_accumulate_grad(forward_node, grad_value): if forward_node in val_to_grad and forward_node.op != "placeholder": @@ -349,7 +349,7 @@ class MultiUseParameterConfig(Enum): REPLICATE = 2 -MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]] +MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]] class DetachExecutor(fx.Interpreter): @@ -432,7 +432,7 @@ class _LinearNodeList: def to_graph(self): graph = fx.Graph() - ref_str_to_node: Dict[str, fx.Node] = {} + ref_str_to_node: dict[str, fx.Node] = {} def ref_to_node(arg): if isinstance(arg, _NodeReference): @@ -557,14 +557,14 @@ class Pipe(torch.nn.Module): # Map parameter value to a dictionary that maps the user pipeline module # to the local qualname within that module - params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {} + params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {} for m_qualname, mod in self.split_gm.named_children(): for p_qualname, param in mod.named_parameters(): params_to_users.setdefault(param, {}) params_to_users[param][m_qualname] = p_qualname - self.replicated_params: List[Dict[str, str]] = [ + self.replicated_params: list[dict[str, str]] = [ use_mapping for _, use_mapping in params_to_users.items() if len(use_mapping) > 1 @@ -645,7 +645,7 @@ class Pipe(torch.nn.Module): @staticmethod def _number_and_count_forward_stages(gm: fx.GraphModule): num_stages = 0 - found_idxs: Dict[int, None] = {} + found_idxs: dict[int, None] = {} for node in gm.graph.nodes: if node.op == "call_module" and node.target.startswith("submod_"): node.meta["stage_idx"] = int(node.target[len("submod_") :]) @@ -693,7 +693,7 @@ class Pipe(torch.nn.Module): # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving # parameters relies on the invariant that parameter accesses happen once. This is not necessarily # the case (especially with custom tracers), so fix that up here. - get_attr_nodes: Dict[str, fx.Node] = {} + get_attr_nodes: dict[str, fx.Node] = {} for node in traced.graph.nodes: # type: ignore[union-attr] if node.op == "get_attr": get_attr_nodes.setdefault(node.target, node) @@ -868,7 +868,7 @@ class Pipe(torch.nn.Module): # [aliasing] store tensor id -> list of FQNs, built from state dict # Also assign non-persistent buffers - id_to_fqns: Dict[int, Set[str]] = defaultdict(set) + id_to_fqns: dict[int, set[str]] = defaultdict(set) for fqn, tensor in mod.state_dict(keep_vars=True).items(): id_to_fqns[id(tensor)].add(fqn) for fqn, tensor in mod.named_buffers(): @@ -878,7 +878,7 @@ class Pipe(torch.nn.Module): # need to move the `get_attr` nodes from the root of the graph to those # hierarchies. # [aliasing] use id -> fqn mapping to list out all valid FQNs - inputs_to_state: Dict[str, List[str]] = {} + inputs_to_state: dict[str, list[str]] = {} for attr in attr_nodes: _, tensor = _recursive_getattr_with_parent(mod, attr.target) fqns = list(id_to_fqns[id(tensor)]) @@ -890,7 +890,7 @@ class Pipe(torch.nn.Module): # [aliasing] for each submodule split, assign attributes on FQNs that may be used. # We determine this based on whether or not the FQN attribute parent exists. # i.e. if the last submodule exists, assign the attribute. - added_attributes: Dict[str, List[str]] = defaultdict(list) + added_attributes: dict[str, list[str]] = defaultdict(list) for fqn, tensor in mod.state_dict(keep_vars=True).items(): for name, submod in split.named_children(): if isinstance(submod, fx.GraphModule): @@ -999,7 +999,7 @@ class Pipe(torch.nn.Module): def _trace_with_export( mod: torch.nn.Module, example_args: tuple[Any, ...], - example_kwargs: Optional[Dict[str, Any]] = None, + example_kwargs: Optional[dict[str, Any]] = None, ) -> ExportedProgram: logger.info("Tracing model ...") try: @@ -1023,7 +1023,7 @@ class Pipe(torch.nn.Module): def from_tracing( mod: torch.nn.Module, example_args: tuple[Any, ...], - example_kwargs: Optional[Dict[str, Any]] = None, + example_kwargs: Optional[dict[str, Any]] = None, split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, ): # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across @@ -1173,7 +1173,7 @@ def _split_after_forward(self, *args, **kwargs): pipe_split() -def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): +def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]): # TODO: make this implementation out-of-place? for qualname, split_type in spec.items(): atoms = qualname.split(".") @@ -1200,8 +1200,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]): def pipeline( module: torch.nn.Module, mb_args: tuple[Any, ...], - mb_kwargs: Optional[Dict[str, Any]] = None, - split_spec: Optional[Dict[str, SplitPoint]] = None, + mb_kwargs: Optional[dict[str, Any]] = None, + split_spec: Optional[dict[str, SplitPoint]] = None, split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, ) -> Pipe: """ diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index e2744b32d85..e2eebf49ad7 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -2,7 +2,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import collections import logging -from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Union +from collections.abc import Iterator +from typing import Any, Optional, Union import torch from torch.autograd.graph import GradientEdge, Node @@ -37,8 +38,8 @@ def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]: def reverse_closure( - roots: List[Node], target_nodes: Set[Node], reverse_edges_dict -) -> tuple[Set[Node], Set[Node]]: + roots: list[Node], target_nodes: set[Node], reverse_edges_dict +) -> tuple[set[Node], set[Node]]: """ This function returns the reverse closure of the given roots, i.e. the set of nodes that can be reached from the roots by following the @@ -46,9 +47,9 @@ def reverse_closure( include in the closure. """ # Recurse until we reach a target node - closure: Set[Node] = set() + closure: set[Node] = set() visited_target_nodes = set() - q: Deque[Node] = collections.deque() + q: collections.deque[Node] = collections.deque() for node in roots: if node is not None and node not in closure: closure.add(node) @@ -67,10 +68,10 @@ def reverse_closure( return closure, visited_target_nodes -def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]: - q: Deque[Node] = collections.deque() - root_seen: Set[Node] = set() - reverse_edges_dict: Dict[Node, List[Node]] = collections.defaultdict(list) +def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]: + q: collections.deque[Node] = collections.deque() + root_seen: set[Node] = set() + reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list) for node in roots: if node is not None and node not in root_seen: q.append(node) @@ -86,8 +87,8 @@ def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]: def get_param_groups( - inputs: List[Node], params: List[Node], reverse_edges_dict -) -> List[Dict[str, Any]]: + inputs: list[Node], params: list[Node], reverse_edges_dict +) -> list[dict[str, Any]]: """ Given a list of inputs and a list of parameters, return a list of parameter groups, where each group contains the parameters and the intermediates that @@ -103,12 +104,12 @@ def get_param_groups( # reverse graph that starts with inputs, and goes up to the dOutput or the loss, # but omits weights and any subgraphs connecting weights to this closure inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict) - param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates + param_groups: dict[Node, dict[str, set]] = dict() # keyed on intermediates for param in params: closure, intersected = reverse_closure( [param], inputs_closure, reverse_edges_dict ) - param_group: Dict[str, Set] = { + param_group: dict[str, set] = { "params": {param}, "intermediates": intersected, } @@ -124,8 +125,8 @@ def get_param_groups( param_groups[input_node] = param_group # Sanity check: union of all param_groups params should be equal to all params - union_params: Set[Node] = set() - seen_ids: Set[int] = set() + union_params: set[Node] = set() + seen_ids: set[int] = set() unique_param_groups = [] for param_group in param_groups.values(): if id(param_group) not in seen_ids: @@ -140,11 +141,11 @@ def get_param_groups( def stage_backward_input( - stage_outputs_or_loss: List[torch.Tensor], - output_grads: Optional[List[torch.Tensor]], - input_values: List[torch.Tensor], + stage_outputs_or_loss: list[torch.Tensor], + output_grads: Optional[list[torch.Tensor]], + input_values: list[torch.Tensor], weights: Iterator[Parameter], -) -> tuple[tuple[Optional[torch.Tensor], ...], List[Dict[str, Any]]]: +) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]: """ Compute the gradients for only the stage inputs with respect to the stage outputs (if non-last stage) or loss (if last stage) @@ -155,13 +156,13 @@ def stage_backward_input( Detaching the stage_outputs_or_loss at the end of this function is important as it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need). """ - stage_output_grad_fns: List[Node] = list( + stage_output_grad_fns: list[Node] = list( filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss)) ) - stage_input_grad_fns: List[Node] = list( + stage_input_grad_fns: list[Node] = list( filter(None, map(_get_grad_fn_or_grad_acc, input_values)) ) - weight_grad_fns: List[Node] = list( + weight_grad_fns: list[Node] = list( filter(None, map(_get_grad_fn_or_grad_acc, weights)) ) @@ -222,11 +223,11 @@ def stage_backward_input( def stage_backward_weight( - weights: Iterator[Parameter], param_groups: List[Dict[str, Any]], retain_graph=False + weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False ) -> tuple[Optional[torch.Tensor], ...]: # map weights to param_group_weights grad_acc_to_weight = {} - weight_grads: List[Optional[torch.Tensor]] = [] + weight_grads: list[Optional[torch.Tensor]] = [] for index, weight in enumerate(weights): grad_acc = _get_grad_fn_or_grad_acc(weight) grad_acc_to_weight[grad_acc] = weight, index @@ -273,7 +274,7 @@ def stage_backward( stage_output, output_grads, input_values, - outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used + outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used ) -> tuple[Optional[torch.Tensor], ...]: """ This is a helper function to: @@ -293,8 +294,8 @@ def stage_backward( try: # stage_output may be a composite datatype like dict. Extract all individual # tensor values here - stage_output_tensors: List[torch.Tensor] = [] - output_grad_tensors: List[Optional[torch.Tensor]] = [] + stage_output_tensors: list[torch.Tensor] = [] + output_grad_tensors: list[Optional[torch.Tensor]] = [] def extract_tensors_with_grads( output_val, @@ -353,7 +354,7 @@ def stage_backward( ) # Extract gradients wrt the input values - grad_inputs: List[Optional[torch.Tensor]] = [] + grad_inputs: list[Optional[torch.Tensor]] = [] for val in input_values: if isinstance(val, torch.Tensor): grad_inputs.append(val.grad) diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index e24fac5fb4e..0ed592f2f8d 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -1,6 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from collections import defaultdict -from typing import Dict, List, Set import torch from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry @@ -9,10 +8,10 @@ from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule: # Create an empty GraphModule to hold the outlined modules new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) - seen_nodes: Dict[str, torch.fx.Node] = {} - seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list) - seen_attrs: Dict[str, Set[str]] = defaultdict(set) - created_modules: Dict[str, torch.nn.Module] = {} + seen_nodes: dict[str, torch.fx.Node] = {} + seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: dict[str, set[str]] = defaultdict(set) + created_modules: dict[str, torch.nn.Module] = {} _ModuleFrame( orig_graph, tuple(orig_graph.nodes), diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index b5c53f9fa55..7e8977c8029 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging from dataclasses import dataclass -from typing import List, Union +from typing import Union import torch from torch import fx @@ -75,8 +75,8 @@ def validate_tensor_metadata(desc, expected, given): def validate_tensors_metadata( desc, - expected_tensors: Union[List[torch.Tensor], tuple[torch.Tensor, ...]], - actual_tensors: Union[List[torch.Tensor], tuple[torch.Tensor, ...]], + expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], ): if len(expected_tensors) != len(actual_tensors): raise PipeliningShapeError( diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index e4ab1aa7679..a495c639101 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional import torch from torch.fx.node import map_aggregate @@ -92,7 +92,7 @@ class TensorChunkSpec: @staticmethod def from_dict( - chunk_dims: Dict[str, int], + chunk_dims: dict[str, int], ): """ A helper for creating a dictionary of `TensorChunkSpec` from a @@ -243,11 +243,11 @@ def _shard_dict_of_args( def split_args_kwargs_into_chunks( args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]], + kwargs: Optional[dict[str, Any]], chunks: int, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, -) -> tuple[List[Tuple], List[Dict]]: + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, +) -> tuple[list[tuple], list[dict]]: """ Given a sequence of args and kwargs, split them into a number of chunks according to their respective chunking specs. @@ -347,7 +347,7 @@ def split_args_kwargs_into_chunks( def merge_chunks( - chunks: List[Any], + chunks: list[Any], chunk_spec, ): """ diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index b4b0ab624b0..c3934fe5546 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -9,17 +9,7 @@ import re from abc import ABC, abstractmethod from collections import Counter, defaultdict from enum import Enum -from typing import ( - Any, - Callable, - Dict, - List, - NamedTuple, - Optional, - Set, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union import torch import torch.distributed as dist @@ -163,7 +153,7 @@ class _Action(NamedTuple): def _format_pipeline_order( - pipeline_order: Dict[int, List[Optional[_Action]]], + pipeline_order: dict[int, list[Optional[_Action]]], error_step_number: Optional[int] = None, ) -> str: """ @@ -230,8 +220,8 @@ class _PipelineSchedule(ABC): n_microbatches: int, loss_fn: Optional[Callable[..., torch.Tensor]] = None, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, ): # From arguments @@ -256,7 +246,7 @@ class _PipelineSchedule(ABC): self._has_backward = self._loss_fn is not None # Holds the losses for each microbatch. - self._internal_losses: List[torch.Tensor] = [] + self._internal_losses: list[torch.Tensor] = [] logger.info("Using %s", self.__class__.__name__) def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): @@ -302,10 +292,10 @@ class _PipelineSchedule(ABC): @abstractmethod def _step_microbatches( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Run one iteration of the pipeline schedule with list of microbatches. @@ -318,7 +308,7 @@ class _PipelineSchedule(ABC): raise NotImplementedError @abstractmethod - def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. Will chunk the input into microbatches automatically, and go through the @@ -333,10 +323,10 @@ class _PipelineSchedule(ABC): def _check_inputs( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Pre-process/check inputs @@ -375,7 +365,7 @@ class _PipelineSchedule(ABC): def _split_inputs( self, args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ): """ Splits a full-batch input into chunks (i.e. microbatches) and returns @@ -395,7 +385,7 @@ class _PipelineSchedule(ABC): # Return a list of empty tuples/dicts with matching length as chunks return [()] * self._n_microbatches, [{}] * self._n_microbatches - def _merge_outputs(self, output_chunks: List[Any]) -> Any: + def _merge_outputs(self, output_chunks: list[Any]) -> Any: """ Merge output chunks back to a batch state. If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). @@ -406,7 +396,7 @@ class _PipelineSchedule(ABC): ) -def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): +def _batch_p2p(p2p_ops: list[dist.P2POp], desc: Optional[str] = None): """ Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. """ @@ -418,8 +408,8 @@ def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): def _sorted_batch_p2p( - p2p_ops: List[dist.P2POp], desc: Optional[str] = None -) -> Dict[int, dist.Work]: + p2p_ops: list[dist.P2POp], desc: Optional[str] = None +) -> dict[int, dist.Work]: """ Sorts the list of P2P ops by the peer rank, and then calls batch_isend_irecv. Return a dictionary of works by peer rank. This function @@ -428,8 +418,8 @@ def _sorted_batch_p2p( # Arrange p2p_ops by peer rank: # int is the peer rank; # List is the list of ops towards the peer - ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) - work_by_peer: Dict[int, dist.Work] = {} + ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list) + work_by_peer: dict[int, dist.Work] = {} if len(p2p_ops) == 0: return work_by_peer @@ -461,8 +451,8 @@ class PipelineScheduleSingle(_PipelineSchedule): n_microbatches: int, loss_fn: Optional[Callable] = None, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, ): # Init parent @@ -493,7 +483,7 @@ or equal to the number of stages ({self._num_stages})." self._stage._prepare_backward_infra(self._n_microbatches) self._stage_initialized = True - def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. Will chunk the input into microbatches automatically, and go through the @@ -535,10 +525,10 @@ class _ScheduleForwardOnly(PipelineScheduleSingle): def _step_microbatches( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Run one iteration of the pipeline schedule @@ -553,7 +543,7 @@ class _ScheduleForwardOnly(PipelineScheduleSingle): self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) # Delay send waits - fwd_sends_to_wait: List[dist.Work] = [] + fwd_sends_to_wait: list[dist.Work] = [] # Run microbatches for i in range(self._n_microbatches): @@ -586,10 +576,10 @@ class ScheduleGPipe(PipelineScheduleSingle): def _step_microbatches( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Run one iteration of the pipeline schedule with list of microbatches. @@ -604,7 +594,7 @@ class ScheduleGPipe(PipelineScheduleSingle): self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) # Delay send waits - fwd_sends_to_wait: List[dist.Work] = [] + fwd_sends_to_wait: list[dist.Work] = [] # Run microbatches for i in range(self._n_microbatches): @@ -636,7 +626,7 @@ class ScheduleGPipe(PipelineScheduleSingle): # Run backward # Delay send waits - bwd_sends_to_wait: List[dist.Work] = [] + bwd_sends_to_wait: list[dist.Work] = [] for i in range(self._n_microbatches): with record_function(f"Backward {i}"): ops = self._stage.get_bwd_recv_ops(i) @@ -677,10 +667,10 @@ class Schedule1F1B(PipelineScheduleSingle): def _step_microbatches( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Run one iteration of the pipeline schedule with list of microbatches. @@ -820,9 +810,9 @@ class Schedule1F1B(PipelineScheduleSingle): def _add_unshard_reshard( - compute_actions: List[Optional[_Action]], + compute_actions: list[Optional[_Action]], max_active_stages: int = 3, -) -> List[_Action]: +) -> list[_Action]: """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. @@ -836,11 +826,11 @@ def _add_unshard_reshard( """ def next_stage_indices( - count: int, next_actions: List[Optional[_Action]] - ) -> List[int]: + count: int, next_actions: list[Optional[_Action]] + ) -> list[int]: """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" - seen: Set[int] = set() - ret: List[int] = [] + seen: set[int] = set() + ret: list[int] = [] for a in next_actions: if a is not None and a.stage_index not in seen: @@ -850,8 +840,8 @@ def _add_unshard_reshard( break return ret - active_stages: Set[int] = set() - fsdp_aware_actions: List[_Action] = [] + active_stages: set[int] = set() + fsdp_aware_actions: list[_Action] = [] def _unshard(stage_index: int): active_stages.add(stage_index) @@ -890,8 +880,8 @@ def _add_unshard_reshard( def _merge_bw( - compute_actions: List[Optional[_Action]], -) -> List[_Action]: + compute_actions: list[Optional[_Action]], +) -> list[_Action]: """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops. (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD) @@ -925,12 +915,12 @@ def _merge_bw( def _add_send_recv( - compute_actions: Dict[int, List[_Action]], + compute_actions: dict[int, list[_Action]], stage_to_rank: Callable[[int], int], num_stages: int, -) -> Dict[int, List[_Action]]: - comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions} - prev_actions: Dict[int, Set[_Action]] = {rank: set() for rank in compute_actions} +) -> dict[int, list[_Action]]: + comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions} + prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions} def _has_comms(action: _Action) -> bool: if action.computation_type == F: @@ -954,7 +944,7 @@ def _add_send_recv( return send, recv def _ready_to_schedule( - action: Optional[_Action], prev_actions: Set[_Action] + action: Optional[_Action], prev_actions: set[_Action] ) -> bool: """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. This helps ensure a sane (non-hanging) ordering of sends and recvs. @@ -1030,7 +1020,7 @@ def _add_send_recv( def _validate_schedule( - actions: Dict[int, List[Optional[_Action]]], + actions: dict[int, list[Optional[_Action]]], pp_group_size: int, num_stages: int, num_microbatches: int, @@ -1043,7 +1033,7 @@ def _validate_schedule( # We will count all the actions per stage and ensure they happen in a valid order # (e.g. F before (B, I) before W for a given microbatch) - stage_actions: Dict[int, Dict[_ComputationType, Set]] = { + stage_actions: dict[int, dict[_ComputationType, set]] = { stage_id: { F: set(), B: set(), @@ -1108,13 +1098,13 @@ class PipelineScheduleMulti(_PipelineSchedule): def __init__( self, - stages: List[_PipelineStageBase], + stages: list[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Callable] = None, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, - stage_index_to_group_rank: Optional[Dict[int, int]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + stage_index_to_group_rank: Optional[dict[int, int]] = None, use_full_backward: Optional[bool] = None, scale_grads: bool = True, ): @@ -1148,7 +1138,7 @@ class PipelineScheduleMulti(_PipelineSchedule): self._should_compute_loss = lambda stage: stage.is_last and has_loss # This will be set during init of derived schedules - self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} if use_full_backward is not None: logger.warning( @@ -1199,7 +1189,7 @@ class PipelineScheduleMulti(_PipelineSchedule): self._n_microbatches, ) - def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): + def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. Will chunk the input into microbatches automatically, and go through the @@ -1235,10 +1225,10 @@ class PipelineScheduleMulti(_PipelineSchedule): def _step_microbatches( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Operate on the microbatches for looped schedules (multiple stages on each rank). @@ -1253,14 +1243,14 @@ class PipelineScheduleMulti(_PipelineSchedule): # Based on the plan in Step 1 created in __init__: # 2. Perform communication based on the pipeline_order - stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage_index_to_stage: dict[int, _PipelineStageBase] = { stage.stage_index: stage for stage in self._stages } # determine prev_rank and next_rank based on which ranks are next to # the stages in the pipeline_order - all_prev_ranks: Set[int] = set() - all_next_ranks: Set[int] = set() + all_prev_ranks: set[int] = set() + all_next_ranks: set[int] = set() for stage_index in stage_index_to_stage.keys(): # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) if stage_index > 0: @@ -1271,7 +1261,7 @@ class PipelineScheduleMulti(_PipelineSchedule): backward_counter: Counter[int] = Counter() for time_step, action in enumerate(self.pipeline_order[self.rank]): try: - ops: List[dist.P2POp] = [] + ops: list[dist.P2POp] = [] if action is not None: computation_type = action.computation_type mb_index = action.microbatch_index @@ -1432,7 +1422,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): def _load_actions( self, - actions: Dict[int, List[Optional[_Action]]], + actions: dict[int, list[Optional[_Action]]], format: str = "compute_only", ): """ @@ -1442,7 +1432,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): assert ( self.stage_index_to_group_rank is not None ), "stage_index_to_group_rank is required for PipelineScheduleRuntime" - self.pipeline_order_with_comms: Dict[int, List[_Action]] = {} + self.pipeline_order_with_comms: dict[int, list[_Action]] = {} if format == "compute_comms": for rank in actions: self.pipeline_order_with_comms[rank] = [] @@ -1507,10 +1497,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): def _step_microbatches( self, - arg_mbs: Optional[List] = None, - kwarg_mbs: Optional[List] = None, - target_mbs: Optional[List] = None, - losses: Optional[List] = None, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, ): """ Operate on the microbatches for looped schedules (multiple stages on each rank). @@ -1524,7 +1514,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): # Based on the plan in Step 1 created in __init__: # 2. Perform communication based on the pipeline_order - stage_index_to_stage: Dict[int, _PipelineStageBase] = { + stage_index_to_stage: dict[int, _PipelineStageBase] = { stage.stage_index: stage for stage in self._stages } @@ -1533,14 +1523,14 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti): ), "Must call _load_actions() before calling _step_microbatches()" # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use - bwd_recv_ops: Dict[tuple[int, int], Work] = {} - fwd_recv_ops: Dict[tuple[int, int], Work] = {} + bwd_recv_ops: dict[tuple[int, int], Work] = {} + fwd_recv_ops: dict[tuple[int, int], Work] = {} # send ops should be waited on before step() exists, mainly for hygeine - send_ops: List[Work] = [] + send_ops: list[Work] = [] # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages - unshard_ops: Dict[int, UnshardHandle] = {} + unshard_ops: dict[int, UnshardHandle] = {} unsharded_stages = set() def _assert_unsharded(stage_idx: int): @@ -1751,10 +1741,10 @@ class ScheduleLoopedBFS(PipelineScheduleMulti): def __init__( self, - stages: List[_PipelineStageBase], + stages: list[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Union[Callable, _Loss]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, ): super().__init__( @@ -1768,7 +1758,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti): # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} # ======================================================================== for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) @@ -1782,7 +1772,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti): # Store the list of operations used for that rank # Pre-padding, rank starts with no-ops based on the warmup. - rank_ops: List[Optional[_Action]] = [None for _ in range(rank)] + rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] for stage_index in stage_indices: rank_ops.extend( @@ -1816,13 +1806,13 @@ def _get_1f1b_rank_ops( enable_zero_bubble=False, ): # All stages start with handling microbatch 0 - fwd_stage_mb_index: Dict[int, int] = defaultdict(int) - bwd_stage_mb_index: Dict[int, int] = defaultdict(int) - weight_stage_mb_index: Dict[int, int] = defaultdict(int) + fwd_stage_mb_index: dict[int, int] = defaultdict(int) + bwd_stage_mb_index: dict[int, int] = defaultdict(int) + weight_stage_mb_index: dict[int, int] = defaultdict(int) # Store the list of operations used for that rank # Pre-padding, rank starts with no-ops based on the warmup. - rank_ops: List[Optional[_Action]] = [None for _ in range(rank)] + rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. # Formula: @@ -1960,12 +1950,12 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti): def __init__( self, - stages: List[_PipelineStageBase], + stages: list[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Callable] = None, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, ): self.pp_group_size = stages[0].group_size @@ -1991,12 +1981,12 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti): # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops - def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: def get_rank_warmup_ops(rank): # Warms up operations for last stage warmups_ops_last_stage = ( @@ -2069,12 +2059,12 @@ class ScheduleInterleavedZeroBubble(PipelineScheduleMulti): def __init__( self, - stages: List[_PipelineStageBase], + stages: list[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Callable] = None, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, ): # TODO: we don't support Zero Bubble with torch.compile so we @@ -2109,7 +2099,7 @@ stage modules that have used torch.compile" # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -2121,7 +2111,7 @@ stage modules that have used torch.compile" self.n_local_stages * self.pp_group_size, ) - def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: def get_rank_warmup_ops(rank): # Warms up operations for last stage warmups_ops_last_stage = ( @@ -2198,10 +2188,10 @@ stage modules that have used torch.compile" return (stage + 1, op, microbatch) not in seen_ops return False - seen_ops: Set[tuple[int, _ComputationType, int]] = set() - result: Dict[int, List[Optional[_Action]]] = {} - next_pointer: Dict[int, int] = {} - bubbles_added: Dict[int, int] = {} + seen_ops: set[tuple[int, _ComputationType, int]] = set() + result: dict[int, list[Optional[_Action]]] = {} + next_pointer: dict[int, int] = {} + bubbles_added: dict[int, int] = {} total_bubbles_added = 0 for rank in range(self.pp_group_size): @@ -2212,7 +2202,7 @@ stage modules that have used torch.compile" while True: should_stop = True - temp_seen_ops: Set[tuple[int, _ComputationType, int]] = set() + temp_seen_ops: set[tuple[int, _ComputationType, int]] = set() for rank in range(self.pp_group_size): timestamp = next_pointer[rank] @@ -2270,13 +2260,13 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti): def __init__( self, - stages: List[_PipelineStageBase], + stages: list[_PipelineStageBase], n_microbatches: int, loss_fn: Optional[Callable] = None, args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None, - stage_index_to_group_rank: Optional[Dict[int, int]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + stage_index_to_group_rank: Optional[dict[int, int]] = None, scale_grads: bool = True, ): self.pp_group_size = stages[0].group_size @@ -2303,16 +2293,16 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti): # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] - self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops - def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: + def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least # as large of the number of microbatches needed to fully utilize the pipeline n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches) - rank_ops: List[Optional[_Action]] = [None for _ in range(rank)] + rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] # Forward and backward action counts for stage chunk 0 and chunk 1 f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0 @@ -2469,11 +2459,11 @@ def _simulate_comms_compute( rank: [a for a in pipeline_order[rank] if a is not None] for rank in sorted(pipeline_order) } - _schedule: Dict[int, List[_Action | None]] = { + _schedule: dict[int, list[_Action | None]] = { rank: [] for rank in sorted(pipeline_order) } - _prev_ops_rank: Dict[int, Set[_Action]] = {rank: set() for rank in _schedule} + _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule} def add_to_schedule(rank: int, action: Optional[_Action]): _schedule[rank].append(action) diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 0e478a41847..01d77ab9a19 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -3,7 +3,7 @@ import logging import operator from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -160,10 +160,10 @@ class _PipelineStageBase(ABC): self.dw_builder = dw_builder # backward state - self.backward_state: Dict[int, tuple[Any, ...]] = {} + self.backward_state: dict[int, tuple[Any, ...]] = {} # store dw_runner per microbatch_id - self.dw_runner: Dict[int, Callable[..., None]] = {} + self.dw_runner: dict[int, Callable[..., None]] = {} # `group_rank` is rank in process group `group`. self.group_rank = dist.get_rank(self.group) @@ -176,11 +176,11 @@ class _PipelineStageBase(ABC): # Run time states self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None # map microbatch ID to list of forward tensor args - self.fwd_cache: Dict[int, tuple[Any, List[torch.Tensor]]] = {} + self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {} # map microbatch ID to list of backward grad tensor args - self.bwd_cache: Dict[int, tuple[Optional[torch.Tensor], ...]] = {} + self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {} # Caching chunk outputs for final output merge or reduction - self.output_chunks: List[Any] = [] + self.output_chunks: list[Any] = [] # Initialize has_backward to false; this will be set to true if loss # function is passed to pipeline schedule @@ -189,16 +189,16 @@ class _PipelineStageBase(ABC): self.log_prefix = f"[Stage {self.stage_index}]" # Forward infra - self.args_recv_info: Dict[int, tuple[InputInfo, ...]] = {} - self.act_send_info: Dict[int, List] = {} + self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} + self.act_send_info: dict[int, list] = {} # Backward infra will created lazily - self.grad_recv_info: Dict = {} - self.grad_send_info: Optional[List] = None + self.grad_recv_info: dict = {} + self.grad_send_info: Optional[list] = None # To be populated later by the Schedule self.chunks: Optional[int] = None - self.stage_index_to_group_rank: Dict[int, int] = { + self.stage_index_to_group_rank: dict[int, int] = { i: i % self.group_size for i in range(self.num_stages) } @@ -258,12 +258,12 @@ class _PipelineStageBase(ABC): def _create_grad_send_info( self, - args_recv_info: Tuple, - ) -> List[Optional[int]]: + args_recv_info: tuple, + ) -> list[Optional[int]]: """ Create a list of stage indices to send gradients to. """ - grad_send_info: List[Optional[int]] = [] + grad_send_info: list[Optional[int]] = [] def map_recv_to_send(a): # Note: we send gradients back to previous stage as long as in @@ -286,7 +286,7 @@ class _PipelineStageBase(ABC): self, num_microbatches: int, args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ) -> tuple[Any, ...]: raise NotImplementedError @@ -303,19 +303,19 @@ class _PipelineStageBase(ABC): @abstractmethod def _create_grad_recv_info( self, - act_send_info: Dict, + act_send_info: dict, ) -> tuple[_RecvInfo, ...]: raise NotImplementedError def _get_recv_ops( self, recv_infos: tuple[InputInfo, ...], - ) -> List[dist.P2POp]: + ) -> list[dist.P2POp]: """ Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. Returns a list of ops that correspond to the recv infos. """ - ops: List[dist.P2POp] = [] + ops: list[dist.P2POp] = [] for info in recv_infos: if not isinstance(info, _RecvInfo): continue @@ -410,7 +410,7 @@ class _PipelineStageBase(ABC): ), f"Expected a recv info, got {type(info)}" info.buffer = tensor - def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: """ Returns a list of ops that are needed to receive the input arguments for this stage. @@ -419,7 +419,7 @@ class _PipelineStageBase(ABC): return self._get_recv_ops(recv_infos) - def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: """ Returns a list of ops that are needed to receive the gradients for this stage. @@ -430,7 +430,7 @@ class _PipelineStageBase(ABC): recv_infos = self.grad_recv_info[bwd_chunk_id] return self._get_recv_ops(recv_infos) - def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]: + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: """ Get the activation send ops for current stage's forward. """ @@ -439,7 +439,7 @@ class _PipelineStageBase(ABC): # `act_send_info` output_tuple = output if type(output) is tuple else (output,) - ops: List[dist.P2POp] = [] + ops: list[dist.P2POp] = [] for idx, out in enumerate(output_tuple): dst_stages = self.act_send_info[idx] @@ -462,7 +462,7 @@ class _PipelineStageBase(ABC): return ops - def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]: + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: """ Get the gradient send ops for current stage's backward. """ @@ -479,7 +479,7 @@ class _PipelineStageBase(ABC): # `grad_send_info` is a mirror of `args_recv_info` self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0]) - ops: List[dist.P2POp] = [] + ops: list[dist.P2POp] = [] grads_input = self.bwd_cache.pop(bwd_chunk_id) for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: @@ -593,9 +593,9 @@ class _PipelineStageBase(ABC): def backward_maybe_with_nosync( self, backward_type, - bwd_kwargs: Dict, + bwd_kwargs: dict, last_backward: bool = False, - ) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[List[Dict[str, Any]]]]: + ) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]: """ Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but @@ -607,7 +607,7 @@ class _PipelineStageBase(ABC): backward_type, ) -> Callable[ [], - tuple[tuple[Optional[torch.Tensor], ...], Optional[List[Dict[str, Any]]]], + tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]], ]: if backward_type == "full": return lambda: ( @@ -686,7 +686,7 @@ class _PipelineStageBase(ABC): self, fwd_chunk_id: int, args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ): """ Perform forward pass on the stage with one microbatch. @@ -817,7 +817,7 @@ class _PipelineStageBase(ABC): "full", bwd_kwargs, last_backward=last_backward ) else: - param_groups: List[Dict[str, Any]] | None = None + param_groups: list[dict[str, Any]] | None = None # Skip the backward for the first stage since we will perform the weight update with # autograd.backward in backward_weight_one_chunk if not self.is_first: @@ -986,7 +986,7 @@ class _PipelineStage(_PipelineStageBase): ) # Create mapping from stage name to stage index - self.submod_to_stage_index: Dict[str, int] = {} + self.submod_to_stage_index: dict[str, int] = {} for i, node in enumerate(submod_nodes): self.submod_to_stage_index.setdefault(node.name, i) @@ -1010,7 +1010,7 @@ class _PipelineStage(_PipelineStageBase): self, num_microbatches: int, args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ) -> tuple[Any, ...]: """ Create send/recv infrastructures for activations (during forward) @@ -1084,7 +1084,7 @@ class _PipelineStage(_PipelineStageBase): buffer, ) - args_recv_info: List[InputInfo] = [] + args_recv_info: list[InputInfo] = [] # Filter out placeholder nodes from `self.submod` (a GraphModule) placeholders = filter( # type: ignore[var-annotated] lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr] @@ -1134,7 +1134,7 @@ class _PipelineStage(_PipelineStageBase): be consumed by multiple stages. """ # Output index: List of receiver ranks - act_send_info: Dict[int, List] = {} + act_send_info: dict[int, list] = {} out_idx = 0 for user in self.node.users: @@ -1171,13 +1171,13 @@ class _PipelineStage(_PipelineStageBase): def _create_grad_recv_info( self, - act_send_info: Dict, + act_send_info: dict, ) -> tuple[_RecvInfo, ...]: """ Create a tuple of `_RecvInfo` for gradients. """ # Dict[output_index, _RecvInfo] - grad_recv_info: Dict[int, _RecvInfo] = {} + grad_recv_info: dict[int, _RecvInfo] = {} output_node = self._get_output_node() # The output node may take multiple args, meaning the submod having multiple output values. @@ -1275,7 +1275,7 @@ class PipelineStage(_PipelineStageBase): dw_builder: Optional[Callable[[], Callable[..., None]]] = None, ): super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) - self.inputs: Optional[List[torch.Tensor]] = None + self.inputs: Optional[list[torch.Tensor]] = None self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) becuase it # might be breaking for existing users. @@ -1313,7 +1313,7 @@ class PipelineStage(_PipelineStageBase): ) # these are the buffers used in backwards send/recv, they are allocated later - self.outputs_grad: List[torch.Tensor] = [] + self.outputs_grad: list[torch.Tensor] = [] def stage_global_rank(peer_rank): return ( @@ -1342,7 +1342,7 @@ class PipelineStage(_PipelineStageBase): def _shape_inference( self, args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ): if kwargs is None: kwargs = {} @@ -1443,7 +1443,7 @@ class PipelineStage(_PipelineStageBase): self, num_microbatches: int, args: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ) -> tuple[Any, ...]: # TODO move self.device to an argument from step API (from its input tensors)? assert num_microbatches is not None, "TODO fix num_microbatches" @@ -1481,7 +1481,7 @@ class PipelineStage(_PipelineStageBase): # Send info during forward for each activation # only need the rank that is being sent to - self.act_send_info: Dict[int, List] = {} + self.act_send_info: dict[int, list] = {} for idx in range(len(self.get_outputs_meta())): # We assume we always send to stage + 1 @@ -1494,7 +1494,7 @@ class PipelineStage(_PipelineStageBase): def _create_grad_recv_info( self, - act_send_info: Dict, + act_send_info: dict, ) -> tuple[_RecvInfo, ...]: grad_recv_info: tuple[_RecvInfo, ...] = () if not self.is_last: diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 6dcccacf928..d5169c58161 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -9,15 +9,16 @@ except ImportError as e: import numbers import os import sys +from collections.abc import Iterator from datetime import timedelta -from typing import Callable, Dict, Iterator, Optional +from typing import Callable, Optional from torch.distributed import FileStore, Store, TCPStore from .constants import default_pg_timeout -_rendezvous_handlers: Dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {} +_rendezvous_handlers: dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {} __all__ = ["register_rendezvous_handler", "rendezvous"] @@ -54,14 +55,14 @@ def register_rendezvous_handler(scheme, handler): # Query will have format "rank=0&world_size=1" and is # converted into {"rank": 0, "world_size": 1} -def _query_to_dict(query: str) -> Dict[str, str]: +def _query_to_dict(query: str) -> dict[str, str]: return { pair[0]: pair[1] for pair in (pair.split("=") for pair in filter(None, query.split("&"))) } -def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool: +def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool: # libuv is the default backend for TCPStore. To enable the non-libuv backend, # user can explicitly specify ``use_libuv=0`` in the URL parameter. return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index 2819b9312f4..a17185f520a 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -3,8 +3,9 @@ import logging import os import threading import warnings +from collections.abc import Generator from datetime import timedelta -from typing import Generator, Tuple +from typing import Tuple from urllib.parse import urlparse import torch diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index fbe77e46574..164ba4056ee 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -7,7 +7,7 @@ import functools import inspect import logging import threading -from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar +from typing import Any, Generic, TYPE_CHECKING, TypeVar import torch from torch._C._distributed_rpc import ( @@ -115,9 +115,9 @@ class AllGatherStates: # States used by `def _all_gather()`. # `_ALL_WORKER_NAMES` is initialized on initializing RPC layer. -_ALL_WORKER_NAMES: Set[Any] = set() +_ALL_WORKER_NAMES: set[Any] = set() _all_gather_dict_lock = threading.RLock() -_all_gather_sequence_id: Dict[str, int] = {} +_all_gather_sequence_id: dict[str, int] = {} _all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( AllGatherStates ) diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 5e15767c0b0..5a8a026c918 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -3,7 +3,7 @@ import collections import enum -from typing import cast, Dict, List, Set +from typing import cast import torch import torch.distributed as dist @@ -163,8 +163,8 @@ def _tensorpipe_validate_devices(devices, device_count): def _tensorpipe_exchange_and_check_all_device_maps( my_name, my_device_count, my_device_maps, my_devices, group ): - gathered: List[ - tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]] + gathered: list[ + tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]] ] = [("", 0, {}, []) for _ in range(group.size())] dist.all_gather_object( gathered, (my_name, my_device_count, my_device_maps, my_devices), group @@ -253,7 +253,7 @@ def _validate_device_maps( def _create_device_list(my_devices, my_device_maps, reverse_device_maps): if not my_devices: - devices_set: Set[torch.device] = set() + devices_set: set[torch.device] = set() for map_ in my_device_maps.values(): devices_set.update(map_.keys()) for map_ in reverse_device_maps.values(): @@ -265,7 +265,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps): def _create_reverse_mapping(my_name, all_names, all_device_maps): - reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {} + reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {} for node in all_names: if my_name in all_device_maps[node]: reverse_device_maps[node] = { diff --git a/torch/distributed/rpc/constants.py b/torch/distributed/rpc/constants.py index 56f6db4db25..f0eaf92b8ae 100644 --- a/torch/distributed/rpc/constants.py +++ b/torch/distributed/rpc/constants.py @@ -1,5 +1,4 @@ from datetime import timedelta -from typing import List from torch._C._distributed_rpc import ( _DEFAULT_INIT_METHOD, @@ -22,4 +21,4 @@ DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) # Value indicating that timeout is not set for RPC call, and the default should be used. UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT -__all__: List[str] = [] +__all__: list[str] = [] diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 53bf473ba56..2be42a38ee2 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase @@ -23,10 +23,10 @@ def _to_device(device: DeviceType) -> torch.device: def _to_device_map( - device_map: Dict[DeviceType, DeviceType] -) -> Dict[torch.device, torch.device]: - full_device_map: Dict[torch.device, torch.device] = {} - reverse_map: Dict[torch.device, torch.device] = {} + device_map: dict[DeviceType, DeviceType] +) -> dict[torch.device, torch.device]: + full_device_map: dict[torch.device, torch.device] = {} + reverse_map: dict[torch.device, torch.device] = {} for k, v in device_map.items(): k, v = torch.device(k), torch.device(v) if v in reverse_map: @@ -39,7 +39,7 @@ def _to_device_map( return full_device_map -def _to_device_list(devices: List[DeviceType]) -> List[torch.device]: +def _to_device_list(devices: list[DeviceType]) -> list[torch.device]: return list(map(_to_device, devices)) @@ -83,10 +83,10 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, init_method: str = rpc_contants.DEFAULT_INIT_METHOD, - device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None, - devices: Optional[List[DeviceType]] = None, - _transports: Optional[List] = None, - _channels: Optional[List] = None, + device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None, + devices: Optional[list[DeviceType]] = None, + _transports: Optional[list] = None, + _channels: Optional[list] = None, ): full_device_maps = ( {} @@ -104,7 +104,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): full_device_list, ) - def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]): + def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]): r""" Set device mapping between each RPC caller and callee pair. This function can be called multiple times to incrementally add @@ -162,7 +162,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): super()._set_device_map(to, full_device_map) - def set_devices(self, devices: List[DeviceType]): + def set_devices(self, devices: list[DeviceType]): r""" Set local devices used by the TensorPipe RPC agent. When processing CUDA RPC requests, the TensorPipe RPC agent will properly synchronize diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index 46696599032..7b9a4d0bcde 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -2,7 +2,6 @@ # mypy: allow-untyped-defs import itertools -from typing import List import torch from torch.autograd.profiler_legacy import profile @@ -13,7 +12,7 @@ from . import ( ) -__all__: List[str] = [] +__all__: list[str] = [] class _server_process_global_profile(profile): diff --git a/torch/distributed/run.py b/torch/distributed/run.py index b227b9ff129..de2f5a2f5d7 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -398,7 +398,7 @@ import sys import uuid from argparse import ArgumentParser, REMAINDER from importlib import metadata -from typing import Callable, List, Optional, Set, Type, Union +from typing import Callable, Optional, Union import torch from torch.distributed.argparse_util import check_env, env @@ -736,7 +736,7 @@ def get_use_env(args) -> bool: return args.use_env -def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: +def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: """ Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. Provides plugin mechanism to provide custom implementation of LogsSpecs. @@ -770,7 +770,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: return logs_specs_cls -def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str]]: +def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) assert 0 < min_nodes <= max_nodes @@ -810,7 +810,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str rdzv_endpoint = get_rdzv_endpoint(args) - ranks: Optional[Set[int]] = None + ranks: Optional[set[int]] = None if args.local_ranks_filter: try: ranks = set(map(int, args.local_ranks_filter.split(","))) @@ -820,7 +820,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" ) from e - logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) + logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) logs_specs = logs_specs_cls( log_dir=args.log_dir, redirects=Std.from_str(args.redirects), diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 47d811de99f..3e608778e5a 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -1,18 +1,9 @@ # mypy: allow-untyped-defs import dataclasses import traceback -from typing import ( - Any, - Callable, - Container, - Dict, - List, - Optional, - OrderedDict, - overload, - Set, - TypeVar, -) +from collections import OrderedDict +from collections.abc import Container +from typing import Any, Callable, Optional, overload, TypeVar import torch import torch.distributed as dist @@ -43,8 +34,8 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str, kwarg keys. The second tuple element gives the kwarg keys. The second tuple element's length is at most the first tuple element's length. """ - kwarg_keys: List[str] = [] - flat_args: List[Any] = list(args) + kwarg_keys: list[str] = [] + flat_args: list[Any] = list(args) for k, v in kwargs.items(): kwarg_keys.append(k) flat_args.append(v) @@ -75,7 +66,7 @@ def _cast_forward_inputs( def _unpack_kwargs( flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...] -) -> tuple[tuple[Any, ...], Dict[str, Any]]: +) -> tuple[tuple[Any, ...], dict[str, Any]]: """See _pack_kwargs.""" assert len(kwarg_keys) <= len( flat_args @@ -94,7 +85,7 @@ T = TypeVar("T", torch.Tensor, PackedSequence) @overload def _recursive_to( inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool -) -> List[S]: +) -> list[S]: ... @@ -264,10 +255,10 @@ def _apply_to_tensors(fn, container): def _to_kwargs( inputs: tuple[Any, ...], - kwargs: Optional[Dict[str, Any]], + kwargs: Optional[dict[str, Any]], target_device: torch.device, use_side_stream_for_tensor_copies: bool, -) -> tuple[tuple[Any, ...], tuple[Dict[str, Any], ...]]: +) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: moved_inputs = ( _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies) if inputs @@ -287,7 +278,7 @@ def _to_kwargs( def _verify_param_shape_across_processes( process_group: dist.ProcessGroup, - tensors: List[torch.Tensor], + tensors: list[torch.Tensor], logger: Optional["dist.Logger"] = None, ): return dist._verify_params_across_processes(process_group, tensors, logger) @@ -309,7 +300,7 @@ def _sync_module_states( parameter shapes are consistent before running the synchronization. This can be checked with ``_verify_param_shape_across_processes``. """ - module_states: List[torch.Tensor] = [] + module_states: list[torch.Tensor] = [] for name, param in module.named_parameters(): if name not in params_and_buffers_to_ignore: module_states.append(param.detach()) @@ -324,7 +315,7 @@ def _sync_module_states( def _sync_params_and_buffers( process_group: dist.ProcessGroup, - module_states: List[torch.Tensor], + module_states: list[torch.Tensor], broadcast_bucket_size: int, src: int, ) -> None: @@ -336,7 +327,7 @@ def _sync_params_and_buffers( def _replace_by_prefix( - state_dict: Dict[str, Any], + state_dict: dict[str, Any], old_prefix: str, new_prefix: str, ) -> None: @@ -363,15 +354,15 @@ def _data_ptr_allocated(tensor: torch.Tensor) -> bool: return tensor.untyped_storage().data_ptr() > 0 -def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]: +def _get_root_modules(modules: list[nn.Module]) -> list[nn.Module]: """ Returns the modules in ``modules`` that are root modules (i.e. parent-less) with respect to the set ``modules``. In other words, these are the modules in ``modules`` that are the not child of any other module in ``modules``. """ - root_modules: List[nn.Module] = [] - module_to_modules: Dict[nn.Module, Set[nn.Module]] = { + root_modules: list[nn.Module] = [] + module_to_modules: dict[nn.Module, set[nn.Module]] = { module: set(module.modules()) for module in modules } for candidate_module in modules: