PEP585 update - torch/distributed (#145164)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
Authored by Aaron Orenstein on 2025-01-20 14:50:01 -08:00; committed by PyTorch MergeBot
commit 00ffeca1b1 (parent c6986ca2e1)
79 changed files with 805 additions and 860 deletions
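For context, the change applied throughout these diffs is the mechanical PEP 585 rewrite: subscript the builtin containers (`list`, `dict`, `set`, `tuple`, `type`) directly, and import abstract types such as `Sequence`, `Mapping`, `Iterator`, and `Generator` from `collections.abc` instead of using the deprecated `typing` aliases. A minimal sketch of the pattern follows; the function names are illustrative only and do not come from any file in this commit.

# Before: typing aliases (deprecated since Python 3.9)
from typing import Dict, List, Optional

def gather_sizes_old(shards: List[int]) -> Dict[str, Optional[int]]:
    # Sum the shard sizes; None when there are no shards.
    return {"total": sum(shards) if shards else None}

# After: builtin generics, with ABCs imported from collections.abc
from collections.abc import Sequence

def gather_sizes(shards: list[int]) -> dict[str, Optional[int]]:
    return {"total": sum(shards) if shards else None}

def first_shard(shards: Sequence[int]) -> Optional[int]:
    return shards[0] if shards else None

The builtin generics have been subscriptable at runtime since Python 3.9, so this is essentially an annotation cleanup with no behavior change.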

View File

@ -79,7 +79,7 @@ if is_available():
finally:
sys.stdin = _stdin
_breakpoint_cache: typing.Dict[int, typing.Any] = {}
_breakpoint_cache: dict[int, typing.Any] = {}
def breakpoint(rank: int = 0, skip: int = 0):
"""

View File

@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import List, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
import torch
@ -12,7 +12,7 @@ class _Checkpointable(Protocol): # noqa: PYI046
This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
"""
def __create_write_items__(self, fqn: str, object: object) -> List[object]:
def __create_write_items__(self, fqn: str, object: object) -> list[object]:
"""
Return a list of WriteItems based on object's contents.
"""
@ -20,7 +20,7 @@ class _Checkpointable(Protocol): # noqa: PYI046
"_Checkpointable._create_write_items is not implemented"
)
def __create_chunk_list__(self) -> List[object]:
def __create_chunk_list__(self) -> list[object]:
"""
Return a list of `ChunkStorageMetadata` based on object's contents.
"""

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from collections.abc import Generator
from contextlib import contextmanager, nullcontext
from typing import Any, ContextManager, Dict, Generator, Optional
from typing import Any, ContextManager, Optional
import torch
import torch.nn as nn
@ -85,7 +86,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
)
def forward_pre_hook(
module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
if checkpoint.state(module).enable_hook:

View File

@ -1,20 +1,9 @@
# mypy: allow-untyped-defs
import uuid
from collections import OrderedDict
from collections.abc import Sequence
from functools import wraps
from typing import (
Callable,
Dict,
Generic,
List,
Optional,
overload,
Protocol,
Sequence,
Type,
TypeVar,
Union,
)
from typing import Callable, Generic, Optional, overload, Protocol, TypeVar, Union
from typing_extensions import Concatenate, ParamSpec
import torch
@ -66,7 +55,7 @@ def contract() -> (
@overload
def contract(
state_cls: Type[_TState],
state_cls: type[_TState],
) -> Callable[
[Callable[Concatenate[nn.Module, _P], Optional[nn.Module]]],
_ContractFn[Concatenate[nn.Module, _P], _T, _TState],
@ -75,7 +64,7 @@ def contract(
def contract(
state_cls: Type = _State,
state_cls: type = _State,
) -> Callable[
[
Callable[
@ -153,21 +142,21 @@ def contract(
# `func` is allowed to return different module instances than the
# input modules as long as FQNs are preserved following the input
# module order
all_orig_named_params: List[Dict[str, nn.Parameter]] = []
all_orig_named_buffers: List[Dict[str, torch.Tensor]] = []
all_orig_named_modules: List[Dict[str, nn.Module]] = []
all_orig_named_params: list[dict[str, nn.Parameter]] = []
all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
all_orig_named_modules: list[dict[str, nn.Module]] = []
for module in modules:
default_all_state: Dict[Callable, _State] = OrderedDict()
default_registry: Dict[str, RegistryItem] = OrderedDict()
all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
default_all_state: dict[Callable, _State] = OrderedDict()
default_registry: dict[str, RegistryItem] = OrderedDict()
all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
STATE_KEY, default_all_state
)
if not isinstance(all_state, dict):
raise AssertionError(
f"Distributed composable API states corrupted: {all_state}"
)
registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
REGISTRY_KEY, default_registry
)
if not isinstance(registry, dict):
@ -195,9 +184,9 @@ def contract(
else:
updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type]
all_new_named_params: List[Dict[str, nn.Parameter]] = []
all_new_named_buffers: List[Dict[str, torch.Tensor]] = []
all_new_named_modules: List[Dict[str, nn.Module]] = []
all_new_named_params: list[dict[str, nn.Parameter]] = []
all_new_named_buffers: list[dict[str, torch.Tensor]] = []
all_new_named_modules: list[dict[str, nn.Module]] = []
for module in updated_modules:
all_new_named_params.append(OrderedDict(module.named_parameters()))
all_new_named_buffers.append(OrderedDict(module.named_buffers()))
@ -212,7 +201,7 @@ def contract(
f"Outputs: {num_new_modules} modules"
)
def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
if orig_fqns == new_fqns:
return
@ -280,7 +269,7 @@ def contract(
return inner
def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]:
r"""
Get an ``OrderedDict`` of composable APIs that have been applied to the
``module``, indexed by the API name. If no API has been applied, then this

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import weakref
from typing import Any, Dict, Iterable, List, NoReturn, Optional, Set
from collections.abc import Iterable
from typing import Any, NoReturn, Optional
import torch
import torch.nn as nn
@ -24,17 +25,17 @@ class _ReplicateState(_State):
# TODO(@fegin): this variable is originally create for testing, we
# should remove this if possible.
self._orig_module = self.module
self._param_names: List[str] = []
self._param_names: list[str] = []
self._no_sync: bool = False
self._init_args: Optional[tuple[Any, ...]] = None
self._init_kwargs: Dict[str, Any] = {}
self._comm_hook_args: List[Any] = []
self._init_kwargs: dict[str, Any] = {}
self._comm_hook_args: list[Any] = []
def _collect_params(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
ignored_modules: set[nn.Module],
ignored_params: set[nn.Parameter],
prefix: str = _ROOT_MODULE_PREFIX,
) -> None:
# skip if managed by fully_sharded API
@ -76,7 +77,7 @@ class _ReplicateState(_State):
def init(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
ignored_modules: set[nn.Module],
**kwargs,
) -> None:
if self.has_initialized:
@ -125,7 +126,7 @@ class _ReplicateState(_State):
self._init_kwargs = kwargs
def forward_pre_hook(
self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Any:
if self._init_args or self._init_kwargs:
self.lazy_init()

View File

@ -2,7 +2,7 @@
import contextlib
import sys
import warnings
from typing import Any, cast, List, Optional, Type, TYPE_CHECKING, Union
from typing import Any, cast, List, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -94,8 +94,8 @@ Functional collectives can accept any of these types to describe the ranks parti
The different types will be desugared to a canonical format
"""
RANK_TYPES = Union[
List[int],
List[List[int]],
list[int],
list[list[int]],
dist.ProcessGroup,
DeviceMesh,
tuple["dist.tensor.DeviceMesh", int],
@ -331,8 +331,8 @@ def reduce_scatter_tensor_autograd(
def all_reduce_coalesced(
self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
) -> list[torch.Tensor]:
"""
Reduces a list of tensors across all machines in such a way that all get
the final result.
@ -359,8 +359,8 @@ def all_reduce_coalesced(
def all_gather_into_tensor_coalesced(
self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> List[torch.Tensor]:
self: list[torch.Tensor], group: RANK_TYPES, tag: str = ""
) -> list[torch.Tensor]:
"""
Gather a list of tensors across from all machines.
@ -388,12 +388,12 @@ def all_gather_into_tensor_coalesced(
def reduce_scatter_tensor_coalesced(
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
reduceOp: str,
scatter_dim: List[int],
scatter_dim: list[int],
group: RANK_TYPES,
tag: str = "",
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
"""
Reduces a list of tensors across all machines in such a way that all get
the final result, then scatter the results to corresponding ranks.
@ -450,8 +450,8 @@ def _is_view_op(tgt):
def all_to_all_single(
self: torch.Tensor,
output_split_sizes: Optional[List[int]],
input_split_sizes: Optional[List[int]],
output_split_sizes: Optional[list[int]],
input_split_sizes: Optional[list[int]],
group: RANK_TYPES,
tag: str = "",
) -> torch.Tensor:
@ -498,8 +498,8 @@ def all_to_all_single(
def all_to_all_single_autograd(
self: torch.Tensor,
output_split_sizes: Optional[List[int]],
input_split_sizes: Optional[List[int]],
output_split_sizes: Optional[list[int]],
input_split_sizes: Optional[list[int]],
group: RANK_TYPES,
tag: str = "",
) -> torch.Tensor:
@ -535,7 +535,7 @@ def all_to_all_single_autograd(
def permute_tensor(
self: torch.Tensor,
src_dst: List[int],
src_dst: list[int],
group: RANK_TYPES,
tag: str = "",
) -> torch.Tensor:
@ -608,7 +608,7 @@ class AsyncCollectiveTensor(torch.Tensor):
return AsyncCollectiveTensor(elem)
def __coerce_same_metadata_as_tangent__(
self, expected_metadata: Any, expected_type: Optional[Type] = None
self, expected_metadata: Any, expected_type: Optional[type] = None
):
if expected_type is not torch.Tensor:
return None
@ -677,7 +677,7 @@ Utils and infrastructure for tracing support
"""
def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int]:
def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]:
"""
_expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.
@ -690,10 +690,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int
if TYPE_CHECKING:
def cast_listlistint(x):
return cast(List[List[int]], x)
return cast(list[list[int]], x)
def cast_listint(x):
return cast(List[int], x)
return cast(list[int], x)
else:
# fake cast op for use at runtime since dynamo doesn't support real cast
@ -705,7 +705,7 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int
def cast_listint(x):
return x
rankset: List[int]
rankset: list[int]
if isinstance(group, list):
if isinstance(group[0], list):
nested_list = cast_listlistint(group)
@ -1143,7 +1143,7 @@ def all_to_all_inplace(
def all_gather_inplace(
tensor_list: List[torch.Tensor],
tensor_list: list[torch.Tensor],
tensor: torch.Tensor,
group=None,
async_op=False,

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
import torch
import torch.distributed.distributed_c10d as c10d
@ -60,7 +60,7 @@ def _reduce_scatter_tensor(
input: torch.Tensor,
reduce_op: str,
tag: str,
ranks: List[int],
ranks: list[int],
group_size: int,
):
group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
@ -73,10 +73,10 @@ def _reduce_scatter_tensor(
def _reduce_scatter_tensor_coalesced(
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
reduce_op: str,
tag: str,
ranks: List[int],
ranks: list[int],
group_size: int,
):
group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)
@ -90,10 +90,10 @@ def _reduce_scatter_tensor_coalesced(
def _all_to_all_single(
input: torch.Tensor,
output_split_sizes: Optional[List[int]],
input_split_sizes: Optional[List[int]],
output_split_sizes: Optional[list[int]],
input_split_sizes: Optional[list[int]],
tag: str,
ranks: List[int],
ranks: list[int],
group_size: int,
):
if output_split_sizes is None or input_split_sizes is None:

View File

@ -1,4 +1,4 @@
from typing import Sequence
from collections.abc import Sequence
import torch
from torch.distributed._shard.metadata import ShardMetadata

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import reduce
from typing import List, Optional, Union
from typing import Optional, Union
from torch.distributed.remote_device import _remote_device
@ -25,14 +25,14 @@ class ShardMetadata:
__slots__ = ["shard_offsets", "shard_sizes", "placement"]
shard_offsets: List[int]
shard_sizes: List[int]
shard_offsets: list[int]
shard_sizes: list[int]
placement: Optional[_remote_device]
def __init__(
self,
shard_offsets: List[int],
shard_sizes: List[int],
shard_offsets: list[int],
shard_sizes: list[int],
placement: Optional[Union[str, _remote_device]] = None,
):
self.shard_offsets = shard_offsets

View File

@ -1,4 +1,5 @@
from typing import Iterator, Tuple, Union
from collections.abc import Iterator
from typing import Tuple, Union
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Mapping, Union
from collections.abc import Mapping
from typing import Any, Union
import torch.optim as optim
from torch import Tensor
@ -28,7 +29,7 @@ class ShardedOptimizer(optim.Optimizer):
**optimizer_kwargs: the key-word arguments to initialize the optimizer.
"""
tensors: List[Tensor] = []
tensors: list[Tensor] = []
for value in named_params.values():
if isinstance(value, ShardedTensor):
tensors.extend(
@ -72,7 +73,7 @@ class ShardedOptimizer(optim.Optimizer):
"""
self._optim.step(closure)
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
"""
Returned state and param_groups will contain parameter keys
instead of parameter indices like torch.optim.Optimizer.

View File

@ -356,7 +356,7 @@ def randn(
def init_from_local_shards(
local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False
local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False
) -> ShardedTensor:
"""
Creates an :class:`ShardedTensor` from local shards and the global metadata.

View File

@ -8,7 +8,7 @@ import warnings
import weakref
from dataclasses import dataclass
from functools import reduce
from typing import Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING
from typing import Callable, cast, Optional, TYPE_CHECKING
from typing_extensions import deprecated
import torch
@ -41,23 +41,25 @@ from .utils import (
if TYPE_CHECKING:
from collections.abc import Sequence
from torch.distributed._shard.metadata import ShardMetadata
# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0
_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {}
_sharded_tensor_map: dict[int, weakref.ReferenceType[ShardedTensor]] = {}
# Default sharded ops
_SHARDED_OPS: Dict[Callable, Callable] = {}
_SHARDED_OPS: dict[Callable, Callable] = {}
# Customized user ops
_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
_CUSTOM_SHARDED_OPS: dict[Callable, Callable] = {}
def _register_remote_shards(
sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int
sharded_tensor_id: int, rrefs: list[rpc.RRef[Shard]], rpc_rank: int
):
with _sharded_tensor_lock:
if sharded_tensor_id not in _sharded_tensor_map:
@ -75,7 +77,7 @@ def _register_remote_shards(
class ShardedTensorBase(torch.Tensor):
_sharding_spec: shard_spec.ShardingSpec
_metadata: ShardedTensorMetadata
_local_shards: List[Shard]
_local_shards: list[Shard]
def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
# Use __new__ to construct a wrapper tensor, for recording tensor
@ -125,7 +127,7 @@ class ShardedTensorBase(torch.Tensor):
"""
return self._metadata
def local_shards(self) -> List[Shard]:
def local_shards(self) -> list[Shard]:
"""
Returns a list of :class:`Shard' corresponding to the
local shards for this rank. Returns an empty list if the current rank
@ -136,7 +138,7 @@ class ShardedTensorBase(torch.Tensor):
@classmethod
def _init_from_local_shards_and_global_metadata(
cls,
local_shards: List[Shard],
local_shards: list[Shard],
sharded_tensor_metadata: ShardedTensorMetadata,
sharding_spec=None,
) -> ShardedTensorBase:
@ -290,7 +292,7 @@ class ShardedTensor(ShardedTensorBase):
self._sharded_tensor_id = None
self._process_group = self._normalize_pg(process_group)
self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
self._remote_shards: dict[int, list[rpc.RRef[Shard]]] = {}
def _post_init(self):
# Initialize RPC if available.
@ -351,7 +353,7 @@ class ShardedTensor(ShardedTensorBase):
continue
if len(self.local_shards()) != 0:
rrefs: List[rpc.RRef[Shard]] = [
rrefs: list[rpc.RRef[Shard]] = [
rpc.RRef(shard) for shard in self.local_shards()
]
fut = rpc.rpc_async(
@ -430,7 +432,7 @@ class ShardedTensor(ShardedTensorBase):
world_size = dist.get_world_size(self._process_group)
rank_sizes = [0 for _ in range(world_size)]
max_rank_size = 0
shard_placement: Dict[ShardMetadata, tuple[int, int]] = {}
shard_placement: dict[ShardMetadata, tuple[int, int]] = {}
# collect sizes
for shard_md in self.metadata().shards_metadata:
shard_rank = cast(_remote_device, shard_md.placement).rank()
@ -440,7 +442,7 @@ class ShardedTensor(ShardedTensorBase):
rank_sizes[shard_rank] += shard_size(shard_md)
max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
gather_list: Optional[List[torch.Tensor]]
gather_list: Optional[list[torch.Tensor]]
if rank == dst:
assert out is not None
if enforce_dtype:
@ -535,7 +537,7 @@ class ShardedTensor(ShardedTensorBase):
return self
# if not, returns a copy of this object on CPU
list_shards: List[Shard] = []
list_shards: list[Shard] = []
# move all local shards to cpu, and change metadata
for shard in self._local_shards:
cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg]
@ -594,7 +596,7 @@ class ShardedTensor(ShardedTensorBase):
current_device = torch.device(torch.cuda.current_device())
# returns a copy of ShardedTensor on CUDA current device
list_shards: List[Shard] = []
list_shards: list[Shard] = []
# move all local shards to current device, and change metadata
# if local shards already on the current device, there's no
# real data movement, only the metadata are copied.
@ -683,7 +685,7 @@ class ShardedTensor(ShardedTensorBase):
return self
# returns a copy of ShardedTensor on CUDA current device
list_shards: List[Shard] = []
list_shards: list[Shard] = []
for shard in self._local_shards:
new_tensor = shard.tensor.to( # type: ignore[call-overload]
@ -726,7 +728,7 @@ class ShardedTensor(ShardedTensorBase):
@classmethod
def _init_from_local_shards(
cls,
local_shards: List[Shard],
local_shards: list[Shard],
*global_size,
process_group=None,
init_rrefs=False,
@ -746,7 +748,7 @@ class ShardedTensor(ShardedTensorBase):
# STEP 2. Validate metadata across ranks, and build a global sharded tensor
# metadata by gathering local ShardedTensorMetadata
gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
gathered_metadatas: list[Optional[ShardedTensorMetadata]] = []
if world_size > 1:
gathered_metadatas = [None for _ in range(world_size)]
@ -866,7 +868,7 @@ class ShardedTensor(ShardedTensorBase):
process_group = cls._normalize_pg(process_group)
current_rank = dist.get_rank() # intentional to get global rank
local_shards: List[Shard] = []
local_shards: list[Shard] = []
for shard_metadata in sharded_tensor_metadata.shards_metadata:
rank, _device = _parse_and_validate_remote_device(
process_group, shard_metadata.placement
@ -887,7 +889,7 @@ class ShardedTensor(ShardedTensorBase):
@classmethod
def _init_from_local_shards_and_global_metadata( # type: ignore[override]
cls,
local_shards: List[Shard],
local_shards: list[Shard],
sharded_tensor_metadata: ShardedTensorMetadata,
process_group=None,
init_rrefs=False,
@ -1190,11 +1192,11 @@ class ShardedTensor(ShardedTensorBase):
return self._metadata.tensor_properties.pin_memory
def _register_remote_shards(
self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int
self, remote_shards: list[rpc.RRef[Shard]], rpc_rank: int
):
self._remote_shards[rpc_rank] = remote_shards
def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]:
def remote_shards(self) -> dict[int, list[rpc.RRef[Shard]]]:
"""
Returns a Dict[int, RRef] with keys being the RPC rank and values
being RRefs to shards on that rank. Need to initialize the

View File

@ -7,12 +7,11 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import List
from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers
__all__: List[str] = []
__all__: list[str] = []
def _get_or_create_logger() -> logging.Logger:

View File

@ -7,11 +7,10 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List
__all__: List[str] = []
__all__: list[str] = []
_log_handlers: Dict[str, logging.Handler] = {
_log_handlers: dict[str, logging.Handler] = {
"default": logging.NullHandler(),
}

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass, field
from enum import Enum
from typing import List
import torch
from torch.distributed._shard.metadata import ShardMetadata
@ -87,7 +86,7 @@ class ShardedTensorMetadata:
"""
# Metadata about each shard of the Tensor
shards_metadata: List[ShardMetadata] = field(default_factory=list)
shards_metadata: list[ShardMetadata] = field(default_factory=list)
# Size of each dim of the overall Tensor.
size: torch.Size = field(default=torch.Size([]))

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import copy
from typing import List
import torch
import torch.distributed as dist
@ -44,7 +43,7 @@ def build_reshard_metadata(
st_size: torch.Size,
sharding_spec: shard_spec.ShardingSpec,
world_size: int,
) -> tuple[List[ShardMetadata], List[int]]:
) -> tuple[list[ShardMetadata], list[int]]:
"""
Based the given sharding spec, we calculate the offset and local shard size.
We then build a ShardMetadata on top of the calculation result.
@ -86,7 +85,7 @@ def reshuffle_local_shard(
sharding_spec: shard_spec.ShardingSpec,
resharding_spec: shard_spec.ShardingSpec,
pg: ProcessGroup,
) -> tuple[List[Shard], List[ShardMetadata]]:
) -> tuple[list[Shard], list[ShardMetadata]]:
"""
Reshuffle the local shard directly when the reshard dim is same as the original
sharding dim. Logically we do this in two step:
@ -155,7 +154,7 @@ def reshard_local_shard(
sharding_spec: shard_spec.ShardingSpec,
resharding_spec: shard_spec.ShardingSpec,
pg: ProcessGroup,
) -> tuple[List[Shard], List[ShardMetadata]]:
) -> tuple[list[Shard], list[ShardMetadata]]:
"""
Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
different from the original sharding dim, we need to do two steps logically:
@ -198,7 +197,7 @@ def reshard_local_shard(
if rearrange_input:
# Need to re-arrange reshard_dim of local_tensor before all2all.
indices: List[int] = []
indices: list[int] = []
for metadata in shards_metadata:
offset_start_idx = metadata.shard_offsets[reshard_dim]
split_size = metadata.shard_sizes[reshard_dim]

View File

@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import List
import torch
from torch.distributed._shard.metadata import ShardMetadata
@ -43,7 +42,7 @@ class Shard:
@classmethod
def from_tensor_and_offsets(
cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int
cls, tensor: torch.Tensor, shard_offsets: list[int], rank: int
) -> "Shard":
"""
Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import collections.abc
import copy
from typing import List, Optional, Sequence, TYPE_CHECKING
from collections.abc import Sequence
from typing import Optional, TYPE_CHECKING
import torch
from torch.distributed import distributed_c10d as c10d, rpc
@ -109,13 +110,13 @@ def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
def build_metadata_from_local_shards(
local_shards: List[Shard],
local_shards: list[Shard],
global_size: torch.Size,
current_rank: int,
pg: c10d.ProcessGroup,
) -> ShardedTensorMetadata:
assert len(local_shards) > 0, "must have local shards!"
local_shard_metadatas: List[ShardMetadata] = []
local_shard_metadatas: list[ShardMetadata] = []
first_shard_dtype = local_shards[0].tensor.dtype
first_shard_layout = local_shards[0].tensor.layout

View File

@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch.nn as nn
from torch.distributed._shard.sharder import Sharder
@ -62,9 +62,9 @@ class ShardingPlan:
>>> )
"""
plan: Dict[str, Union[ShardingSpec, Sharder]]
output_plan: Optional[Dict[str, ShardingSpec]] = None
return_local_tensor: Optional[List[str]] = None
plan: dict[str, Union[ShardingSpec, Sharder]]
output_plan: Optional[dict[str, ShardingSpec]] = None
return_local_tensor: Optional[list[str]] = None
class ShardingPlanner(abc.ABC):

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional
from typing import Optional
from torch.distributed._shard.metadata import ShardMetadata
@ -24,7 +24,7 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad
def _find_nd_overlapping_shards(
shards: List[ShardMetadata], sharded_dims: List[int]
shards: list[ShardMetadata], sharded_dims: list[int]
) -> Optional[tuple[int, int]]:
# Each rank has len(sharded_dims) tuples. Each tuple represent the
# [begin, end] (inclusive) pair of that dimension.
@ -55,7 +55,7 @@ def _find_nd_overlapping_shards(
def _find_1d_overlapping_shards(
shards: List[ShardMetadata], dim: int
shards: list[ShardMetadata], dim: int
) -> Optional[tuple[int, int]]:
# (begin, end, index_in_shards). Begin and end are inclusive.
intervals = [
@ -69,7 +69,7 @@ def _find_1d_overlapping_shards(
return None
def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]):
"""
Ensures none of the shards overlap with each other.
@ -82,7 +82,7 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
if not shards or len(shards) == 1:
return
sharded_dims: List[int] = []
sharded_dims: list[int] = []
for dim in range(len(shards[0].shard_offsets)):
for i in range(1, len(shards)):
if (

View File

@ -3,7 +3,7 @@ import functools
import operator
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING
import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
@ -95,7 +95,7 @@ class ShardingSpec(ABC):
# Ops customized for a particular ShardingSpec.
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}
_CUSTOM_SHARDING_SPEC_OPS: dict[str, dict[Callable, Callable]] = {}
def _has_custom_op(sharding_spec, op):
@ -148,7 +148,7 @@ class EnumerableShardingSpec(ShardingSpec):
each shard. Note that none of the shards should overlap.
"""
shards: List[ShardMetadata]
shards: list[ShardMetadata]
def __post_init__(self):
if len(self.shards) == 0:

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import cast, List, Optional, TYPE_CHECKING, Union
from typing import cast, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -53,7 +53,7 @@ class ChunkShardingSpec(ShardingSpec):
ShardingDim = Union[int, str]
dim: ShardingDim
placements: List[Union[torch.distributed._remote_device, str]]
placements: list[Union[torch.distributed._remote_device, str]]
def __post_init__(self):
self._verify_dim(self.dim)
@ -134,7 +134,7 @@ class ChunkShardingSpec(ShardingSpec):
local_metadata = None
tensors_to_scatter = cast(
List[Optional[torch.Tensor]],
list[Optional[torch.Tensor]],
[None] * dist.get_world_size(process_group),
)
@ -195,7 +195,7 @@ class ChunkShardingSpec(ShardingSpec):
process_group, src_for_scatter
)
tensors_to_scatter_: Optional[List[torch.Tensor]] = None
tensors_to_scatter_: Optional[list[torch.Tensor]] = None
if current_rank == src_rank:
tensors_to_scatter_ = []
for t in tensors_to_scatter:

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
from typing import cast, List
from typing import cast
import torch
import torch.distributed as dist
@ -222,7 +222,7 @@ def _validate_embedding_bag_param(args, kwargs):
)
if include_last_offset and offsets is None:
raise ValueError('offsets is required for flag "include_last_offset"!')
if include_last_offset and cast(List[int], offsets)[-1] != input.size(0):
if include_last_offset and cast(list[int], offsets)[-1] != input.size(0):
raise ValueError(
'offsets need to have the input size in the end when the flag "include_last_offset" is on!'
)

View File

@ -3,19 +3,8 @@ import copy
import io
import math
import weakref
from typing import (
Any,
Callable,
cast,
Dict,
List,
Mapping,
MutableMapping,
NamedTuple,
Optional,
TYPE_CHECKING,
Union,
)
from collections.abc import Mapping, MutableMapping
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -94,7 +83,7 @@ def _iterate_state_dict(
ranks_only: tuple[int, ...] = (),
type_check: bool = True,
non_blocking: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Iterate through the state dict, applying the given functions to each tensor type.
Args:
@ -203,14 +192,14 @@ def _iterate_state_dict(
def _gather_state_dict(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
*,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
cpu_offload: bool = False,
ranks_only: tuple[int, ...] = (),
type_check: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Given a state_dict, this API gathers all the ShardedTensors or DTensors in
the state_dict.
@ -291,11 +280,11 @@ def _gather_state_dict(
def _offload_state_dict_to_cpu(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
*,
ranks_only: tuple[int, ...] = (),
type_check: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Given a state_dict, this API offload all the tensors to CPU memory.
@ -331,11 +320,11 @@ def _offload_state_dict_to_cpu(
def _copy_state_dict(
state_dict: Dict[str, Any],
copy_state_dict: Dict[str, Any],
state_dict: dict[str, Any],
copy_state_dict: dict[str, Any],
non_blocking: bool = False,
type_check: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Copies all tensors in a given state dict into a different state_dict with the
same structure. Additionally, a copied state dict with the same value references
@ -380,8 +369,8 @@ def _copy_state_dict(
def _create_cpu_state_dict(
state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> dict[str, Any]:
"""
Given a state_dict, create another state_dict with the same structure and elements.
However, all tensors in the returned state_dict are new tensors on CPU. These
@ -449,8 +438,8 @@ def _create_cpu_state_dict(
def _check_state_dict_similarity(
state_dict: Dict[str, Any],
compared_state_dict: Dict[str, Any],
state_dict: dict[str, Any],
compared_state_dict: dict[str, Any],
) -> bool:
"""
Given two state_dicts, check if the structures are the same. And
@ -496,9 +485,9 @@ class _TensorInfo(NamedTuple):
def _broadcast_tensors(
full_state_dict: Dict[str, Any],
local_state_dict: Dict[str, Any],
keys: List[str],
full_state_dict: dict[str, Any],
local_state_dict: dict[str, Any],
keys: list[str],
device: torch.device,
pg: Optional[dist.ProcessGroup] = None,
) -> None:
@ -536,8 +525,8 @@ def _broadcast_tensors(
def _distribute_tensors(
local_state_dict: Dict[str, Any],
keys: List[str],
local_state_dict: dict[str, Any],
keys: list[str],
device: torch.device,
pg: Optional[dist.ProcessGroup] = None,
) -> None:
@ -578,8 +567,8 @@ def _distribute_tensors(
def _broadcast_state_dict(
full_state_dict: Dict[str, Any],
local_state_dict: Dict[str, Any],
full_state_dict: dict[str, Any],
local_state_dict: dict[str, Any],
device: torch.device,
pg: Optional[dist.ProcessGroup] = None,
strict: bool = False,
@ -637,8 +626,8 @@ def _broadcast_state_dict(
def _distribute_state_dict(
full_state_dict: Dict[str, Any],
local_state_dict: Dict[str, Any],
full_state_dict: dict[str, Any],
local_state_dict: dict[str, Any],
device: torch.device,
pg: Optional[dist.ProcessGroup] = None,
) -> None:
@ -672,8 +661,8 @@ def _distribute_state_dict(
# DCP.
PATH_ITEM = Union[str, int]
OBJ_PATH = tuple[PATH_ITEM, ...]
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
STATE_DICT_TYPE = Dict[str, Any]
FLATTEN_MAPPING = dict[str, OBJ_PATH]
STATE_DICT_TYPE = dict[str, Any]
CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]
@ -731,14 +720,14 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
"""Set ``value`` in ``root_dict`` along the ``path`` object path."""
cur_container = cast(CONTAINER_TYPE, root_dict)
def extend_list(lst: List[Any], idx: int) -> None:
def extend_list(lst: list[Any], idx: int) -> None:
while len(lst) <= idx:
lst.append(None)
for i in range(1, len(path)):
prev_key = path[i - 1]
key = path[i]
def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else []
def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else []
if isinstance(cur_container, Mapping):
cur_container = cast(
@ -752,7 +741,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
key = path[-1]
if type(key) == int:
extend_list(cast(List[Any], cur_container), key)
extend_list(cast(list[Any], cur_container), key)
cur_container[key] = value

View File

@ -2,11 +2,12 @@ import math
import os
import socket
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from datetime import timedelta
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed._functional_collectives as funcol
@ -15,7 +16,7 @@ from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
_group_name_to_store: Dict[str, c10d.Store] = {}
_group_name_to_store: dict[str, c10d.Store] = {}
def enable_symm_mem_for_group(group_name: str) -> None:
@ -95,7 +96,7 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool:
return _is_test_mode or group_name in _group_name_to_store
_group_name_to_workspace_tensor: Dict[str, Optional[torch.Tensor]] = {}
_group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {}
def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
@ -139,7 +140,7 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
return _SymmetricMemory.rendezvous(tensor)
_backend_streams: Dict[int, torch.cuda.Stream] = {}
_backend_streams: dict[int, torch.cuda.Stream] = {}
def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream:
@ -149,9 +150,9 @@ def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream:
def _pipelined_multi_all_gather_and_consume(
shard: List[torch.Tensor],
shard_consumer: Callable[[List[torch.Tensor], int], None],
ag_out: List[torch.Tensor],
shard: list[torch.Tensor],
shard_consumer: Callable[[list[torch.Tensor], int], None],
ag_out: list[torch.Tensor],
group_name: str,
ag_out_needed: bool = True,
) -> None:
@ -195,11 +196,11 @@ def _pipelined_multi_all_gather_and_consume(
assert x.shape[0] * group_size == y.shape[0]
assert x.shape[1:] == y.shape[1:]
def copy_shard(dst: List[torch.Tensor], src: List[torch.Tensor]) -> None:
def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None:
for d, s in zip(dst, src):
d.copy_(s)
def get_p2p_bufs(remote_rank: int) -> List[torch.Tensor]:
def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]:
offset_bytes = 0
bufs = []
for x in shard:
@ -216,7 +217,7 @@ def _pipelined_multi_all_gather_and_consume(
local_p2p_bufs = get_p2p_bufs(rank)
# shards[i] => shard from rank i
shards: List[List[torch.Tensor]] = [[] for _ in range(group_size)]
shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)]
for x in ag_out:
for i, y in enumerate(x.chunk(group_size)):
shards[i].append(y)
@ -311,7 +312,7 @@ def _pipelined_all_gather_and_consume(
shard_consumer(shard, src_rank)
"""
def adapter(shard: List[torch.Tensor], rank: int) -> None:
def adapter(shard: list[torch.Tensor], rank: int) -> None:
shard_consumer(shard[0], rank)
_pipelined_multi_all_gather_and_consume(
@ -505,14 +506,14 @@ def _check_and_verify_fp8_all_gather_scale_mode(
def _fused_all_gather_matmul_impl(
mm_out_op: torch._ops.OpOverload,
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
A_scale: Optional[torch.Tensor],
kwargs_list: List[Dict[str, Any]],
out_dtypes: List[Optional[torch.dtype]],
kwargs_list: list[dict[str, Any]],
out_dtypes: list[Optional[torch.dtype]],
gather_dim: int,
group_name: str,
return_A: bool,
) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]:
) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]:
if A_shard.dim() < 2:
raise ValueError("A_shard must be a matrix")
for B in Bs:
@ -563,7 +564,7 @@ def _fused_all_gather_matmul_impl(
A_scale_shard.shape[1],
)
def row_wise_sharded_consumer(shard: List[torch.Tensor], rank: int) -> None:
def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None:
for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
mm_out_op(
shard[0],
@ -630,12 +631,12 @@ def _fused_all_gather_matmul_impl(
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
def _fused_all_gather_matmul_fallback(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
gather_dim: int,
group_name: str,
*,
return_A: bool = True,
) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]:
) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]:
group_size = c10d._get_group_size_by_name(group_name)
A = torch.ops._c10d_functional.all_gather_into_tensor(
A_shard.contiguous(), group_size, group_name
@ -652,12 +653,12 @@ def _fused_all_gather_matmul_fallback(
@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA")
def _fused_all_gather_matmul(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
gather_dim: int,
group_name: str,
*,
return_A: bool = True,
) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]:
) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]:
"""
Perform the following logic with micro-pipelined computation and
communication:
@ -703,7 +704,7 @@ def _fused_all_gather_matmul(
def _should_use_fused_all_gather_matmul_native(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
gather_dim: int,
group_name: str,
) -> bool:
@ -806,9 +807,9 @@ def _should_use_multimem_all_gather_matmul(
def _multimem_all_gather_matmul(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
group_name: str,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
group = c10d._resolve_process_group(group_name)
A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:]))
symm_mem = get_symm_mem_workspace(
@ -822,16 +823,16 @@ def _multimem_all_gather_matmul(
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
def _fused_all_gather_scaled_matmul_fallback(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
A_scale: torch.Tensor,
B_scales: List[torch.Tensor],
B_scales: list[torch.Tensor],
gather_dim: int,
group_name: str,
biases: List[Optional[torch.Tensor]],
result_scales: List[Optional[torch.Tensor]],
out_dtypes: List[Optional[torch.dtype]],
use_fast_accum: List[bool],
) -> tuple[torch.Tensor, List[torch.Tensor]]:
biases: list[Optional[torch.Tensor]],
result_scales: list[Optional[torch.Tensor]],
out_dtypes: list[Optional[torch.dtype]],
use_fast_accum: list[bool],
) -> tuple[torch.Tensor, list[torch.Tensor]]:
out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
group_size = c10d._get_group_size_by_name(group_name)
@ -896,16 +897,16 @@ def _fused_all_gather_scaled_matmul_fallback(
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
def _fused_all_gather_scaled_matmul(
A_shard: torch.Tensor,
Bs: List[torch.Tensor],
Bs: list[torch.Tensor],
A_scale: torch.Tensor,
B_scales: List[torch.Tensor],
B_scales: list[torch.Tensor],
gather_dim: int,
group_name: str,
biases: List[Optional[torch.Tensor]],
result_scales: List[Optional[torch.Tensor]],
out_dtypes: List[Optional[torch.dtype]],
use_fast_accum: List[bool],
) -> tuple[torch.Tensor, List[torch.Tensor]]:
biases: list[Optional[torch.Tensor]],
result_scales: list[Optional[torch.Tensor]],
out_dtypes: list[Optional[torch.dtype]],
use_fast_accum: list[bool],
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""
Perform the following logic with micro-pipelined computation and
communication:
@ -976,7 +977,7 @@ def _fused_all_gather_scaled_matmul(
def make_contiguous_for_perm(
t: torch.Tensor,
perm: List[int],
perm: list[int],
) -> torch.Tensor:
"""
Restride `t` such that `t.permute(perm)` is contiguous.
@ -1005,7 +1006,7 @@ def _fused_matmul_reduce_scatter_impl(
A: torch.Tensor,
B: torch.Tensor,
A_scale: Optional[torch.Tensor],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
out_dtype: Optional[torch.dtype],
reduce_op: str,
scatter_dim: int,
@ -1237,8 +1238,8 @@ def restride_A_for_fused_matmul_reduce_scatter(
def _maybe_convert_scalar_types_to_dtypes(
scalar_types: List[Any],
) -> List[Optional[torch.dtype]]:
scalar_types: list[Any],
) -> list[Optional[torch.dtype]]:
"""
When a list of `torch.dtype`s is passed through the dispatcher as
`ScalarType[]`, it is converted to a list of scalar type enum values. This
@ -1270,7 +1271,7 @@ def _maybe_convert_scalar_types_to_dtypes(
if any(not isinstance(x, (type(None), int)) for x in scalar_types):
return scalar_types
dtypes: List[Optional[torch.dtype]] = []
dtypes: list[Optional[torch.dtype]] = []
for scalar_type in scalar_types:
if scalar_type is None:
dtypes.append(scalar_type)
@ -1507,7 +1508,8 @@ def _low_contention_reduce_scatter(
# =============================================================================
from typing import Any, overload, Sequence, TYPE_CHECKING, Union
from collections.abc import Sequence
from typing import Any, overload, TYPE_CHECKING, Union
from torch.types import _device, _dtype, _int

View File

@ -1,7 +1,7 @@
from copy import deepcopy
from datetime import timedelta
from functools import partial, wraps
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, TypeVar, Union
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
import torch
@ -111,9 +111,9 @@ class _FSDPModMemStats:
def __init__(self, mod_fqn: str) -> None:
self.mod_fqn = mod_fqn
self.local_peak: Dict[torch.device, int] = {}
self.snapshots: Dict[
_FSDPModState, List[Dict[torch.device, Dict[str, int]]]
self.local_peak: dict[torch.device, int] = {}
self.snapshots: dict[
_FSDPModState, list[dict[torch.device, dict[str, int]]]
] = {}
@ -169,7 +169,7 @@ class FSDPMemTracker(MemTracker):
self._in_fake_mode: bool = False
self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary()
self._saved_collectives: _SavedCollectives
self._ref_class: Type[_RefType] = _FSDPRefType
self._ref_class: type[_RefType] = _FSDPRefType
def _instrument_fsdp_sharded_params_grads(
self, fsdp_param_group: FSDPParamGroup
@ -190,8 +190,8 @@ class FSDPMemTracker(MemTracker):
def _fsdp_state_pre_forward(
self,
fsdp_mod: FSDPModule,
orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], Dict[str, Any]]],
) -> Callable[_P, tuple[tuple[Unpack[_Ts]], Dict[str, Any]]]:
orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]],
) -> Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]]:
# We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params
# and `all_gather` buffers. There are three cases:
# Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats``
@ -208,7 +208,7 @@ class FSDPMemTracker(MemTracker):
@wraps(orig_fsdp_state_pre_fw)
def inner(
*args: _P.args, **kwargs: _P.kwargs
) -> tuple[tuple[Unpack[_Ts]], Dict[str, Any]]:
) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]:
mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
assert mod_fqn is not None
if fsdp_mod not in self.memory_tracking:
@ -538,7 +538,7 @@ class FSDPMemTracker(MemTracker):
def barrier(
group: Union[ProcessGroup, None] = dist.GroupMember.WORLD,
async_op: bool = False,
device_ids: Union[List[int], None] = None,
device_ids: Union[list[int], None] = None,
) -> Union[Work, None]:
if self._in_fake_mode:
return None

View File

@ -1,5 +1,6 @@
import copy
from typing import cast, Dict, List, OrderedDict, TypedDict
from collections import OrderedDict
from typing import cast, TypedDict
import numpy as np
@ -15,10 +16,10 @@ from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStat
class ModOrder(TypedDict):
fw_pre_order: List[str]
bw_pre_order: List[str]
fw_post_order: List[str]
bw_post_order: List[str]
fw_pre_order: list[str]
bw_pre_order: list[str]
fw_post_order: list[str]
bw_post_order: list[str]
class ModRuntime(TypedDict):
@ -60,18 +61,18 @@ class ModStats(TypedDict):
# Number of piecewise-linear functions used for approximating ac tradeoff curve
n_segments: int
# Slopes of the of piecewise-linear functions
slopes: List[float]
slopes: list[float]
# Intercepts of the of piecewise-linear functions
intercepts: List[float]
intercepts: list[float]
# X breakpoints of the of piecewise-linear functions
breakpoints: List[float]
breakpoints: list[float]
# Original trade-off curves
tradeoff_curve: OrderedDict[float, float]
class ModuleInfo(TypedDict):
mod_order: ModOrder
mod_stats: List[ModStats]
mod_stats: list[ModStats]
def aggregate_stats(
@ -96,12 +97,12 @@ def aggregate_stats(
"""
# Memory stats
mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict(
mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict(
copy.deepcopy(mem_tracker.memory_tracking)
)
# Runtime stats
mod_runtime_stats: Dict[str, ModRuntime] = {
mod_runtime_stats: dict[str, ModRuntime] = {
fqn: {"fw": v["fw"], "bw": v["bw"]}
for fqn, v in runtime_estimator.mod_runtimes.items()
}
@ -116,7 +117,7 @@ def aggregate_stats(
# Selective Activation Checkpointing stats
sac_estimator.pwlf_sac_tradeoff_curve()
mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy(
mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy(
sac_estimator.sac_mod_tradeoff_stats
)
@ -192,10 +193,10 @@ class Node(ModStats):
class Graph:
def __init__(self, n: int) -> None:
self.nodes: List[Node] = []
self.name2node: Dict[str, Node] = {}
self.nodes: list[Node] = []
self.name2node: dict[str, Node] = {}
self.ad_matrix = np.zeros((n, n))
self.fw_post_order: List[str] = []
self.fw_post_order: list[str] = []
def add_node(self, node: Node) -> None:
self.nodes.append(node)

View File

@ -5,7 +5,7 @@ import warnings
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
import torch
@ -116,8 +116,8 @@ class _ModMemStats:
self.buffer_mem: int
self.input_mem: int
self.output_mem: int
self.local_peak: Dict[torch.device, int] = {}
self.snapshots: Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]] = {}
self.local_peak: dict[torch.device, int] = {}
self.snapshots: dict[_ModState, list[dict[torch.device, dict[str, int]]]] = {}
class _WeakRefInfo:
@ -171,7 +171,7 @@ class _WeakRefInfo:
return self.mem_consumed
@staticmethod
def get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Recursively extracts untyped storages from a tensor or its subclasses.
@ -242,7 +242,7 @@ def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]:
return value if divisor == 1 else round(value / divisor, precision)
def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> None:
def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) -> None:
if len(snapshot) == 0:
print("No memory tracked.")
return
@ -261,7 +261,7 @@ def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) ->
def _print_snapshot_tabular(
snapshot: Dict[torch.device, Dict[str, int]], units: str
snapshot: dict[torch.device, dict[str, int]], units: str
) -> None:
if len(snapshot) == 0:
print("No memory tracked.")
@ -287,7 +287,7 @@ def _print_snapshot_tabular(
def _print_state_snapshots(
snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str
) -> None:
for state, snapshot_list in snapshots.items():
print(f"{state}")
@ -298,7 +298,7 @@ def _print_state_snapshots(
def _print_state_snapshots_tabular(
snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str
) -> None:
try:
from tabulate import tabulate
@ -393,9 +393,9 @@ class MemTracker(TorchDispatchMode):
def __init__(self) -> None:
self.memory_tracking = WeakIdKeyDictionary()
self._curr_mem_snap: Dict[torch.device, Dict[str, int]] = {}
self._peak_mem: Dict[torch.device, int] = {}
self._peak_mem_snap: Dict[torch.device, Dict[str, int]] = {}
self._curr_mem_snap: dict[torch.device, dict[str, int]] = {}
self._peak_mem: dict[torch.device, int] = {}
self._peak_mem_snap: dict[torch.device, dict[str, int]] = {}
self._param_to_grad_hook_handles = WeakIdKeyDictionary()
self._optimizer_hook_handles: Optional[
tuple[RemovableHandle, RemovableHandle]
@ -404,7 +404,7 @@ class MemTracker(TorchDispatchMode):
self._WINFO = WeakIdKeyDictionary()
self._mod_tracker = ModTracker()
# This is a general memory tracker which can be used with any ``_RefType`` subclass
self._ref_class: Type[_RefType] = _MemRefType
self._ref_class: type[_RefType] = _MemRefType
# Flags to track if we are in the AC region or optimizer step region
self._in_opt: bool = False
self._in_ac: bool = False
@ -461,7 +461,7 @@ class MemTracker(TorchDispatchMode):
t: torch.Tensor,
reftype: _RefType,
update_existing: bool = False,
) -> Set[_WeakRefInfo]:
) -> set[_WeakRefInfo]:
sts = _WeakRefInfo.get_untyped_storages(t)
winfos = set()
for st in sts:
@ -565,7 +565,7 @@ class MemTracker(TorchDispatchMode):
def get_tracker_snapshot(
self, type: str = "current"
) -> Dict[torch.device, Dict[str, int]]:
) -> dict[torch.device, dict[str, int]]:
"""
Capture a snapshot of the memory usage breakdown per device, based on the specified type.
@ -865,7 +865,7 @@ class MemTracker(TorchDispatchMode):
tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False.
"""
def natural_sort_key(s: str) -> List[Union[int, str]]:
def natural_sort_key(s: str) -> list[Union[int, str]]:
return [
int(text) if text.isdigit() else text.lower()
for text in re.split("([0-9]+)", s)

View File

@ -2,8 +2,9 @@
import operator
import pickle
from collections import defaultdict
from collections.abc import Sequence
from itertools import chain
from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING
from typing import Any, Callable, no_type_check, TYPE_CHECKING
import torch
import torch.nn as nn
@ -72,12 +73,12 @@ class MemoryTracker:
def __init__(self) -> None:
torch._C._log_api_usage_once("torch.distributed.memory_tracker")
self._hooks: List[RemovableHandle] = []
self._operator_names: Dict[str, int] = defaultdict(int)
self.memories_allocated: Dict[int, Dict[str, float]] = defaultdict()
self.memories_active: Dict[int, Dict[str, float]] = defaultdict()
self.memories_reserved: Dict[int, Dict[str, float]] = defaultdict()
self._markers: Dict[str, int] = defaultdict(int)
self._hooks: list[RemovableHandle] = []
self._operator_names: dict[str, int] = defaultdict(int)
self.memories_allocated: dict[int, dict[str, float]] = defaultdict()
self.memories_active: dict[int, dict[str, float]] = defaultdict()
self.memories_reserved: dict[int, dict[str, float]] = defaultdict()
self._markers: dict[str, int] = defaultdict(int)
self._cur_module_name: str = ""
self._op_index: int = 0
self._num_cuda_retries: int = 0
@ -133,7 +134,7 @@ class MemoryTracker:
The number of the top operators can be configured.
"""
op_diff: Dict[str, float] = defaultdict(float)
op_diff: dict[str, float] = defaultdict(float)
op_name, previous_allocated_memory = self.memories_allocated[0]
for i in range(1, self._op_index):
op_name, current_allocated_memory = self.memories_allocated[i]

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import warnings
import weakref
from typing import Callable, Optional, Set
from typing import Callable, Optional
import torch
from torch.autograd.graph import register_multi_grad_hook
@ -48,7 +48,7 @@ class ModTracker:
"""
parents: Set[str]
parents: set[str]
"""
A Set containing the fqn for each module currently running their forward
"""

View File

@ -2,7 +2,7 @@
import math
import os
from collections import defaultdict
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable
from typing_extensions import Self
import torch
@ -121,13 +121,13 @@ class RuntimeEstimator(TorchDispatchMode):
runtime_estimator.display_modulewise_stats()
"""
_float_types: Set[torch.dtype] = {
_float_types: set[torch.dtype] = {
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
}
_no_fallback_kernel: Set[torch._ops._OpNamespace] = set()
_no_fallback_kernel: set[torch._ops._OpNamespace] = set()
fake_mode: FakeTensorMode
def __init__(self) -> None:
@ -135,13 +135,13 @@ class RuntimeEstimator(TorchDispatchMode):
self._estimate: Callable
self._estimate_mode_type: str
self._mod_tracker = ModTracker()
self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict(
self.mod_runtimes: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(lambda: 0.0)
)
self.mod_fw_pre_order: List[str] = []
self.mod_bw_pre_order: List[str] = []
self.mod_fw_post_order: List[str] = []
self.mod_bw_post_order: List[str] = []
self.mod_fw_pre_order: list[str] = []
self.mod_bw_pre_order: list[str] = []
self.mod_fw_post_order: list[str] = []
self.mod_bw_post_order: list[str] = []
self.total_runtime: float = 0.0
# Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950

View File

@ -4,7 +4,7 @@ import sys
import warnings
from collections import OrderedDict
from dataclasses import astuple, dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Set
from typing import Any, NamedTuple, Optional
from typing_extensions import Self
import torch
@ -42,7 +42,7 @@ _PYTORCH_MIN_ALLOCATE = (
)
def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
def _get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Retrieves untyped storages from a `torch.Tensor` or one of its traceable wrapper-subclass.
@ -74,7 +74,7 @@ def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
return flattened_tensor_storages
def _display_stats_tabular(headers: List[str], table_data: List[List[Any]]) -> None:
def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
try:
from tabulate import tabulate
except ImportError as err:
@ -125,7 +125,7 @@ class _SACModMetadata:
start_idx: int
force_store_random: bool
sac_metadata: List[_SACMetadata]
sac_metadata: list[_SACMetadata]
@dataclass
@ -144,13 +144,13 @@ class SACStats:
force_store_random (bool): Whether to force store random operator results.
"""
func_names: List[str]
runtimes: List[float]
memory: List[int]
view_like_ops: List[int]
rand_ops: List[int]
saved_autograd_ops: List[int]
inplace_ops: List[tuple[int, int]]
func_names: list[str]
runtimes: list[float]
memory: list[int]
view_like_ops: list[int]
rand_ops: list[int]
saved_autograd_ops: list[int]
inplace_ops: list[tuple[int, int]]
force_store_random: bool
@ -166,7 +166,7 @@ class MSPS(NamedTuple):
msps (float): Memory per second calculated as memory/runtime.
"""
func_names: Set[str]
func_names: set[str]
op_idx: int
memory: int
runtime: float
@ -189,9 +189,9 @@ class SACTradeOffStats:
"""
n_segments: int
slopes: List[float]
intercepts: List[float]
fit_breaks: List[float]
slopes: list[float]
intercepts: list[float]
fit_breaks: list[float]
tradeoff_curve: OrderedDict[float, float]
sac_memory: int
sac_runtime: float
@ -210,11 +210,11 @@ class SACGreedyOrderMeta:
msps_meta (List[MSPS]): List of Memory and Runtime Statistics for operators.
"""
recomputed_ops: Set[int]
stored_ops: Set[int]
inplace_op_groups: Dict[int, Set[int]]
random_ops_group: Dict[int, Set[int]]
msps_meta: List[MSPS]
recomputed_ops: set[int]
stored_ops: set[int]
inplace_op_groups: dict[int, set[int]]
random_ops_group: dict[int, set[int]]
msps_meta: list[MSPS]
class SACEstimator(TorchDispatchMode):
@ -251,17 +251,17 @@ class SACEstimator(TorchDispatchMode):
"""
def __init__(self) -> None:
self.sac_mod_stats: Dict[str, SACStats] = {}
self.sac_mod_tradeoff_stats: Dict[str, SACTradeOffStats] = {}
self.sac_mod_greedy_order_meta: Dict[str, SACGreedyOrderMeta] = {}
self.sac_mod_stats: dict[str, SACStats] = {}
self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {}
self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {}
self._mod_tracker = ModTracker()
self._sac_metadata: List[_SACMetadata] = []
self._sac_mod_metadata: Dict[str, _SACModMetadata] = {}
self._leaf_modules: Set[str] = set()
self._sac_metadata: list[_SACMetadata] = []
self._sac_mod_metadata: dict[str, _SACModMetadata] = {}
self._leaf_modules: set[str] = set()
self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks(
self._pack_hook, lambda x: x
)
self._saved_tensor_ids: Set[int] = set()
self._saved_tensor_ids: set[int] = set()
self._estimate_runtime = RuntimeEstimator._roofline_estimate
def _pack_hook(self, x: torch.Tensor) -> torch.Tensor:
@ -313,7 +313,7 @@ class SACEstimator(TorchDispatchMode):
return all(not isinstance(x, torch.Tensor) for x in flat_inputs)
def _get_sac_stats(
self, data: List[_SACMetadata], force_store_random: bool
self, data: list[_SACMetadata], force_store_random: bool
) -> SACStats:
# 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd
# inserts those during backward and it breaks the fwd-bwd alignment
@ -381,12 +381,12 @@ class SACEstimator(TorchDispatchMode):
)
def _get_inplace_metadata(
self, func: Any, out_storages: Set[UntypedStorage]
) -> tuple[int, tuple[int, ...], Dict[str, tuple[int, ...]]]:
self, func: Any, out_storages: set[UntypedStorage]
) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]:
# 1. Get the current index of the metadata obtained so far
curr_idx = len(self._sac_metadata)
# 2. Get the set of active modules that are not leaf
active_mod_fqns: Set[str] = {
active_mod_fqns: set[str] = {
par for par in self._mod_tracker.parents if par not in self._leaf_modules
}
# 3. Output ids are the identifiers of the storage objects corresponding to the tensors
@ -397,7 +397,7 @@ class SACEstimator(TorchDispatchMode):
op_idx = curr_idx
# 5. Initialize the parent op ids of the inplace op for each of the active modules
mod_op_parent_idxs: Dict[str, int] = {
mod_op_parent_idxs: dict[str, int] = {
mod_fqn: -1 for mod_fqn in active_mod_fqns
}
for i, d in enumerate(self._sac_metadata):
@ -430,9 +430,9 @@ class SACEstimator(TorchDispatchMode):
# 1. Get the runtime estimate
out, op_time = self._estimate_runtime(func, args, kwargs)
flat_outs, _ = tree_flatten(out)
out_storages_cuda: Set[UntypedStorage] = set()
out_storages_cpu: Set[UntypedStorage] = set()
cuda_devices: Set[torch.device] = set()
out_storages_cuda: set[UntypedStorage] = set()
out_storages_cpu: set[UntypedStorage] = set()
cuda_devices: set[torch.device] = set()
for o in flat_outs:
if isinstance(o, torch.Tensor):
if o.device.type == "cuda":
@ -496,8 +496,8 @@ class SACEstimator(TorchDispatchMode):
# 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group
# The top-most op can itself be an inplace-op or can be a non-inplace op.
# 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads.
inplace_op_groups: Dict[int, Set[int]] = {}
inplace_op_to_group_head: Dict[int, int] = dict(sac_stats.inplace_ops)
inplace_op_groups: dict[int, set[int]] = {}
inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops)
# Initialize inplace_op_groups using inplace_op_to_group_head
for op_idx, group_head_idx in inplace_op_to_group_head.items():
@ -508,7 +508,7 @@ class SACEstimator(TorchDispatchMode):
# as a group. This is because they affect the random seed generator. If force_store_random is set to True,
# all of the random ops will be stored by default. For ease of manageability, we store the top-most random op
# as the leader of the random_ops_group.
random_ops_group: Dict[int, Set[int]] = {}
random_ops_group: dict[int, set[int]] = {}
random_group_head_idx = min(sac_stats.rand_ops, default=-1)
has_rand_ops = bool(sac_stats.rand_ops)
if has_rand_ops:
@ -521,8 +521,8 @@ class SACEstimator(TorchDispatchMode):
# b) If any op in the group is random and force_store_random is set, then the entire group will be stored.
# c) If none of the ops in the group are random and the head of the group is not an in-place op, then
# this group can be considered for recomputation in its entirety
stored_ops: Set[int] = set()
recomputed_ops: Set[int] = set()
stored_ops: set[int] = set()
recomputed_ops: set[int] = set()
# Case 1:
if has_rand_ops and sac_stats.force_store_random:
stored_ops.add(random_group_head_idx)
@ -541,7 +541,7 @@ class SACEstimator(TorchDispatchMode):
stored_ops.add(group_head_idx)
# The potential recompute candidates are populated as:
recompute_candidates: Set[int] = set()
recompute_candidates: set[int] = set()
# 1) The random group head if it is not stored
if has_rand_ops and random_group_head_idx not in stored_ops:
recompute_candidates.add(random_group_head_idx)
@ -557,7 +557,7 @@ class SACEstimator(TorchDispatchMode):
)
# We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second
msps_meta: List[MSPS] = []
msps_meta: list[MSPS] = []
for cand_idx in recompute_candidates:
op_indices = {cand_idx}
if cand_idx in inplace_op_groups:
@ -598,7 +598,7 @@ class SACEstimator(TorchDispatchMode):
greedy_order_meta.msps_meta,
)
# 1. Initialize the discarded memory and recomputation runtime to the sum of the already chosen recomputed_ops
recomp_indices: Set[int] = set()
recomp_indices: set[int] = set()
for r_idx in recomputed_ops:
recomp_indices.add(r_idx)
if r_idx in inplace_op_groups:
@ -628,7 +628,7 @@ class SACEstimator(TorchDispatchMode):
recomp_runtime / sac_runtime
)
# 6. Finally, we add the memory and recomputation time of the always stored ops.
stored_indices: Set[int] = set()
stored_indices: set[int] = set()
for s_idx in stored_ops:
stored_indices.add(s_idx)
if s_idx in inplace_op_groups:
@ -653,7 +653,7 @@ class SACEstimator(TorchDispatchMode):
# save prediction graph
def save_prediction_graph(
pwlf_: pwlf.PiecewiseLinFit, x: List[float], y: List[float], filename: str
pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str
) -> None:
try:
import matplotlib.pyplot as plt # type: ignore[import-not-found]
@ -811,8 +811,8 @@ class SACEstimator(TorchDispatchMode):
recomp_runtime: float = 0.0
def append_row(
op_indices: Set[int],
func_names: Set[str],
op_indices: set[int],
func_names: set[str],
msps: Optional[float] = None,
stored: Optional[bool] = False,
recomputed: Optional[bool] = False,
@ -839,7 +839,7 @@ class SACEstimator(TorchDispatchMode):
)
for op_idx in recomputed_ops:
op_indices: Set[int] = {op_idx}
op_indices: set[int] = {op_idx}
if op_idx in inplace_op_groups:
op_indices.update(inplace_op_groups[op_idx])
if op_idx in random_ops_group:
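
As a quick, purely illustrative reading of the memory-savings-per-second ratio defined in this file (the numbers below are made up, not measured):

    # Hypothetical candidate: recomputing it frees 64 MiB and costs 2 ms
    memory_bytes = 64 * 1024 * 1024
    runtime_s = 0.002
    msps = memory_bytes / runtime_s  # bytes of memory saved per second of recompute time

A larger msps means a candidate buys more memory per unit of recomputation time, which is what makes high-msps ops attractive recomputation candidates in the greedy ordering above.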

View File

@ -1,7 +1,7 @@
import logging
import math
from enum import IntEnum
from typing import Dict, List, Optional
from typing import Optional
from torch.distributed._tools.ilp_utils import Graph, is_submodule
from torch.distributed._tools.sac_estimator import SACStats
@ -36,9 +36,9 @@ def sac_milp(
graph: Graph,
memory_budget: float,
world_size: int = 1,
ac_units: Optional[List[str]] = None,
fsdp_units: Optional[List[str]] = None,
) -> tuple[Dict[str, float], float, int]:
ac_units: Optional[list[str]] = None,
fsdp_units: Optional[list[str]] = None,
) -> tuple[dict[str, float], float, int]:
"""
MILP to decide which modules to AC and how much memory to discard.
The objective is to minimize recomputation time.
@ -224,7 +224,7 @@ class SACDecision(IntEnum):
def get_optimal_checkpointing_policy_per_module(
sac_stats: SACStats, memory_budget: float
) -> List[int]:
) -> list[int]:
"""
This is adapted from --
https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterator
from enum import auto, Enum
from functools import partial
from typing import Any, Callable, Dict, Iterator, Optional
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
@ -69,10 +70,10 @@ class ActivationWrapper(torch.nn.Module, ABC):
@staticmethod
def _post_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
_post_state_dict_hook() is called after the state_dict() of this FSDP module is executed.
@ -87,7 +88,7 @@ class ActivationWrapper(torch.nn.Module, ABC):
@staticmethod
def _pre_load_state_dict_hook(
module: nn.Module,
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
prefix: str,
*args: Any,
) -> None:

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import inspect
from abc import ABC, abstractmethod
from typing import Dict, Type
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
@ -15,7 +14,7 @@ from torch.optim import Optimizer
# Contains the mappings between the regular and overlapped optimizer types.
_registered_overlapped_optims: Dict[Type, Type] = {}
_registered_overlapped_optims: dict[type, type] = {}
def register_overlapped(optim_cls):
@ -33,7 +32,7 @@ def register_overlapped(optim_cls):
class OverlappedOptimizer(ABC):
def __init__(self, optim_cls: Type) -> None:
def __init__(self, optim_cls: type) -> None:
"""
Initialize the OverlappedOptimizer.
@ -61,7 +60,7 @@ class OverlappedOptimizer(ABC):
class _OverlappedStandardOptimizer(OverlappedOptimizer):
"""Overlaps a regular ``Optimizer``."""
def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None:
super().__init__(optim_cls)
f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
self._opt_hook_state = _OptimizerHookState(f_optim, params)
@ -82,7 +81,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
)
def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs):
"""Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
for clz in inspect.getmro(optim_cls):
try:

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import weakref
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
@ -46,7 +46,7 @@ def _perform_local_step(
# expects `None` in a list position to indicate that the corresponding
# parameter should not be updated
num_local_optim_params = len(zero.optim.param_groups[0]["params"])
gradients: List[Optional[torch.Tensor]] = [
gradients: list[Optional[torch.Tensor]] = [
_NO_PARAM_UPDATE for _ in range(num_local_optim_params)
]
assert (

View File

@ -1,14 +1,14 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, List, no_type_check
from typing import Any, Callable, no_type_check
import torch
import torch.distributed as dist
from torch.autograd import Variable
__all__: List[str] = []
__all__: list[str] = []
_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"

View File

@ -2,7 +2,6 @@
import logging
import math
from collections import defaultdict
from typing import Dict
import torch
import torch.distributed as dist
@ -252,9 +251,9 @@ class PowerSGDState:
self.rng = np.random.RandomState(random_seed)
# Since there is only a single state instance for all the input buckets,
# we need to maintain a dictionary that maps each bucket index to the local error.
self.error_dict: Dict[int, torch.Tensor] = {}
self.p_memory_dict: Dict[int, torch.Tensor] = {}
self.q_memory_dict: Dict[int, torch.Tensor] = {}
self.error_dict: dict[int, torch.Tensor] = {}
self.p_memory_dict: dict[int, torch.Tensor] = {}
self.q_memory_dict: dict[int, torch.Tensor] = {}
# Iteration/step in the training loop.
self.iter = 0
# Compression stats accumulators

View File

@ -2,7 +2,7 @@
import warnings
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, NamedTuple, Optional, Type
from typing import Any, NamedTuple, Optional
import torch
import torch.distributed as dist
@ -165,7 +165,7 @@ class Join:
def __init__(
self,
joinables: List[Joinable],
joinables: list[Joinable],
enable: bool = True,
throw_on_early_termination: bool = False,
**kwargs,
@ -228,7 +228,7 @@ class Join:
def __exit__(
self,
type: Optional[Type[BaseException]],
type: Optional[type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
):
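
The Join.__exit__ change is the same substitution applied to typing.Type: the builtin type is subscriptable under PEP 585, so Optional[type[BaseException]] replaces Optional[Type[BaseException]]. A self-contained, hypothetical context manager with the updated spelling (_Guard is an illustrative name):

    from types import TracebackType
    from typing import Optional


    class _Guard:
        def __enter__(self) -> "_Guard":
            return self

        def __exit__(
            self,
            exc_type: Optional[type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
        ) -> None:
            # Returning None (falsy) propagates any exception raised in the block.
            return None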

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -107,7 +108,7 @@ class PeriodicModelAverager(ModelAverager):
def average_parameters(
self,
params: Union[
Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
):
"""

View File

@ -3,7 +3,8 @@
import logging
import warnings
from collections import OrderedDict
from typing import Dict, Iterable, Union
from collections.abc import Iterable
from typing import Union
import torch
import torch.distributed as dist
@ -159,7 +160,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
def average_parameters(
self,
params: Union[
Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
):
"""

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
# flake8: noqa C101
import itertools
from typing import Dict, Iterable, Iterator, Union
from collections.abc import Iterable, Iterator
from typing import Union
import torch
import torch.distributed as dist
@ -51,7 +52,7 @@ def average_parameters(
def get_params_to_average(
params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
):
"""
Return a list of parameters that need to be averaged.
@ -81,7 +82,7 @@ def get_params_to_average(
def average_parameters_or_parameter_groups(
params: Union[
Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
],
process_group: ProcessGroup,
):

View File

@ -9,7 +9,7 @@
import functools
import logging
from typing import Any, Callable, Dict, List, TypeVar
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec
import torch
@ -18,7 +18,7 @@ from torch.distributed.logging_handlers import _log_handlers
from torch.monitor import _WaitCounter
__all__: List[str] = []
__all__: list[str] = []
_DEFAULT_DESTINATION = "default"
@ -48,7 +48,7 @@ global _c10d_logger
_c10d_logger = _get_or_create_logger()
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
if dist.is_initialized():
group = kwargs.get("group") or kwargs.get("process_group")
msg_dict = {

View File

@ -10,7 +10,7 @@ Each should also handle single rank scenario.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, cast, Generic, List, Optional, TypeVar, Union
from typing import Any, Callable, cast, Generic, Optional, TypeVar, Union
import torch.distributed as dist
@ -106,7 +106,7 @@ def all_gather(
data_or_fn: Union[T, Callable[[], T]],
stage_name: Optional[str] = None,
pg: Optional[dist.ProcessGroup] = None,
) -> List[T]:
) -> list[T]:
"""
A simple all_gather primitive with basic synchronization guard logic,
by checking that the payload from all ranks has the same stage name.
@ -149,11 +149,11 @@ def all_gather(
all_gather_object_enforce_type(pg, total_list, sync_obj)
# Each rank will throw RuntimeError in case of failure on any rank.
stage_name = cast(SyncPayload[T], total_list[0]).stage_name
exception_list: List[tuple[int, Exception]] = []
ret_list: List[T] = []
exception_list: list[tuple[int, Exception]] = []
ret_list: list[T] = []
error_msg: str = ""
for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)):
for i, sp in enumerate(cast(list[SyncPayload[T]], total_list)):
if sp.stage_name != stage_name:
error_msg += (
f"Unexpected stage name received from rank {i}: {sp.stage_name} "
@ -183,7 +183,7 @@ def all_gather(
def all_gather_object_enforce_type(
pg: dist.ProcessGroup,
# pyre-fixme[2]: Parameter must have a type that does not contain `Any`
object_list: List[Any],
object_list: list[Any],
# pyre-fixme[2]: Parameter must have a type other than `Any`
obj: Any,
# pyre-fixme[2]: Parameter must have a type that does not contain `Any`

View File

@ -5,7 +5,7 @@ import math
import threading
from functools import reduce
from itertools import chain
from typing import Dict, List, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
from torch.distributed import is_available
@ -65,15 +65,15 @@ else:
class _MeshEnv(threading.local):
def __init__(self) -> None:
self.mesh_stack: List[DeviceMesh] = []
self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {}
self.mesh_dim_group_options: Dict[
self.mesh_stack: list[DeviceMesh] = []
self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {}
self.mesh_dim_group_options: dict[
int, tuple[str, Optional[C10dBackend.Options]]
] = {}
self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {}
# Record flatten mesh name to its mesh dim index in root mesh.
self.flatten_name_to_root_dims: Dict[
DeviceMesh, Dict[str, tuple[int, ...]]
self.flatten_name_to_root_dims: dict[
DeviceMesh, dict[str, tuple[int, ...]]
] = {}
def get_current_mesh(self) -> "DeviceMesh":
@ -85,7 +85,7 @@ else:
self,
device_mesh: "DeviceMesh",
submesh_dim_names: tuple[str, ...],
submesh_dims: List[tuple[int, ...]],
submesh_dims: list[tuple[int, ...]],
) -> "DeviceMesh":
# Get the submesh dim size from the submesh_dims.
# For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want
@ -286,7 +286,7 @@ else:
def _get_slice_mesh_dims(
self, device_mesh, mesh_dim_names
) -> List[tuple[int, ...]]:
) -> list[tuple[int, ...]]:
"""
Check whether mesh_dim_names is valid for slicing the given device_mesh.
If valid, return dim indexes of the slice mesh in the device mesh.
@ -338,7 +338,7 @@ else:
def _get_all_submeshes(
self, device_mesh: "DeviceMesh", mesh_dim_name: str
) -> List["DeviceMesh"]:
) -> list["DeviceMesh"]:
"""
Return all the submeshes of a given mesh dimension of the device mesh.
"""
@ -458,7 +458,7 @@ else:
# calculate the coordinates of the current global rank on the mesh
rank_coords = (self.mesh == get_rank()).nonzero()
assert rank_coords.size(0) in (0, 1)
self._coordinate_on_dim: Optional[List[int]] = (
self._coordinate_on_dim: Optional[list[int]] = (
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
@ -498,7 +498,7 @@ else:
# TODO(yifu): remove tag and ranks once we fully migrate to native
# functional collectives. See details in:
# https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
dim_group_infos: List[tuple[str, List[int], str]] = []
dim_group_infos: list[tuple[str, list[int], str]] = []
default_group = _get_default_group()
if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
@ -775,7 +775,7 @@ else:
_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index]
)
def get_all_groups(self) -> List[ProcessGroup]:
def get_all_groups(self) -> list[ProcessGroup]:
"""
Returns a list of ProcessGroups for all mesh dimensions.
@ -786,7 +786,7 @@ else:
@staticmethod
def from_group(
group: Union[ProcessGroup, List[ProcessGroup]],
group: Union[ProcessGroup, list[ProcessGroup]],
device_type: str,
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
*,
@ -910,7 +910,7 @@ else:
), "We expect ProcessGroup before calling `get_rank`!"
return not_none(get_rank(mesh_dim_group))
def get_coordinate(self) -> Optional[List[int]]:
def get_coordinate(self) -> Optional[list[int]]:
"""
Return the indices of this rank relative to all
dimensions of the mesh. If this rank is not part of the mesh, return None.
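
get_coordinate keeps its behavior and only changes its annotation to Optional[list[int]]. A hedged usage sketch, assuming an initialized process group with eight ranks (the mesh shape and dim names are illustrative):

    from torch.distributed.device_mesh import init_device_mesh

    # 2 x 4 mesh over 8 ranks; each rank gets a (dp, tp) coordinate
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    coord = mesh.get_coordinate()  # Optional[list[int]]
    if coord is not None:
        dp_idx, tp_idx = coord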

View File

@ -15,7 +15,7 @@ import time
import warnings
from collections import namedtuple
from datetime import timedelta
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import deprecated
import torch
@ -259,18 +259,18 @@ class Backend(str):
_BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
_plugins: Dict[str, _BackendPlugin] = {}
_plugins: dict[str, _BackendPlugin] = {}
backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]
# 3rd-party devices can register the default backend support here
default_device_backend_map: Dict[str, str] = {
default_device_backend_map: dict[str, str] = {
"cpu": GLOO,
"cuda": NCCL,
"xpu": XCCL,
}
backend_capability: Dict[str, List[str]] = {
backend_capability: dict[str, list[str]] = {
GLOO: ["cpu", "cuda"],
NCCL: ["cuda"],
XCCL: ["xpu"],
@ -278,7 +278,7 @@ class Backend(str):
MPI: ["cpu", "cuda"],
}
backend_type_map: Dict[str, ProcessGroup.BackendType] = {
backend_type_map: dict[str, ProcessGroup.BackendType] = {
UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
GLOO: ProcessGroup.BackendType.GLOO,
NCCL: ProcessGroup.BackendType.NCCL,
@ -303,7 +303,7 @@ class Backend(str):
name,
func,
extended_api=False,
devices: Optional[Union[str, List[str]]] = None,
devices: Optional[Union[str, list[str]]] = None,
) -> None:
"""
Register a new backend with the given name and instantiating function.
@ -371,7 +371,7 @@ class BackendConfig:
def __init__(self, backend: Backend):
"""Init."""
self.device_backend_map: Dict[str, Backend] = {}
self.device_backend_map: dict[str, Backend] = {}
backend = str(backend)
if backend == Backend.UNDEFINED:
@ -441,7 +441,7 @@ class BackendConfig:
f"{device}:{backend}" for device, backend in self.device_backend_map.items()
)
def get_device_backend_map(self) -> Dict[str, Backend]:
def get_device_backend_map(self) -> dict[str, Backend]:
"""Return backend map of the device."""
return self.device_backend_map
@ -572,14 +572,14 @@ class _CollOp:
# DO NOT USE THESE FIELDS DIRECTLY.
# Use them through the _world object to make sure the _world override mechanism
_pg_map: Dict[ProcessGroup, tuple[str, Store]] = {}
_pg_names: Dict[ProcessGroup, str] = {}
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
_pg_map: dict[ProcessGroup, tuple[str, Store]] = {}
_pg_names: dict[ProcessGroup, str] = {}
_pg_group_ranks: dict[ProcessGroup, dict[int, int]] = {}
# For a pg, it is a map from ProcessGroup to BackendConfig
_pg_backend_config: Dict[ProcessGroup, str] = {}
_pg_backend_config: dict[ProcessGroup, str] = {}
_group_count = 0
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
_pg_to_tag: Dict[ProcessGroup, str] = {}
_tags_to_pg: dict[str, list[ProcessGroup]] = {}
_pg_to_tag: dict[ProcessGroup, str] = {}
_backend: Optional[str] = None
@ -595,7 +595,7 @@ class _World:
def __init__(self) -> None:
self._default_pg = None
self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {}
@property
def default_pg(self) -> Optional[ProcessGroup]:
@ -612,7 +612,7 @@ class _World:
self._default_pg = value
@property
def pg_map(self) -> Dict[ProcessGroup, tuple[str, Store]]:
def pg_map(self) -> dict[ProcessGroup, tuple[str, Store]]:
"""
Provide a mapping from ProcessGroup to backend name and store.
@ -625,7 +625,7 @@ class _World:
return _pg_map
@property
def pg_names(self) -> Dict[ProcessGroup, str]:
def pg_names(self) -> dict[ProcessGroup, str]:
"""
Process group's names, map from ProcessGroup to str.
@ -635,7 +635,7 @@ class _World:
return _pg_names
@property
def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
def pg_group_ranks(self) -> dict[ProcessGroup, dict[int, int]]:
"""
Process group's global rank to local rank mapping.
@ -645,7 +645,7 @@ class _World:
return _pg_group_ranks
@property
def pg_backend_config(self) -> Dict[ProcessGroup, str]:
def pg_backend_config(self) -> dict[ProcessGroup, str]:
"""
Process group's backend config.
@ -671,27 +671,27 @@ class _World:
_group_count = value
@property
def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
def tags_to_pg(self) -> dict[str, list[ProcessGroup]]:
global _tags_to_pg
return _tags_to_pg
@property
def pg_to_tag(self) -> Dict[ProcessGroup, str]:
def pg_to_tag(self) -> dict[ProcessGroup, str]:
global _pg_to_tag
return _pg_to_tag
@property
def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]:
def pg_coalesce_state(self) -> dict[ProcessGroup, list[_CollOp]]:
return self._pg_coalesce_state
@property
def pg_config_info(self) -> List[Dict[str, Any]]:
def pg_config_info(self) -> list[dict[str, Any]]:
"""
Return a list of dicts with process groups and backends,
along with their unique IDs and configurations (types and ranks).
"""
config_info: List[Dict[str, Any]] = []
config_info: list[dict[str, Any]] = []
default_pg_size = _get_group_size(None)
for pg in self.pg_map.keys():
ranks = self.pg_group_ranks[pg]
@ -912,7 +912,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
return rv
def _device_capability(group: Optional[ProcessGroup] = None) -> List[str]:
def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]:
"""
Return the device type(s) supported by ``group``.
@ -1077,7 +1077,7 @@ def _get_global_rank(group, rank) -> int:
return get_global_rank(group, rank)
def get_process_group_ranks(group: ProcessGroup) -> List[int]:
def get_process_group_ranks(group: ProcessGroup) -> list[int]:
"""
Get all ranks associated with ``group``.
@ -1103,7 +1103,7 @@ def _get_group_size_by_name(group_name: str) -> int:
return group.size()
def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str:
def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> str:
# TODO(yifu): remove this function once ranks + tag is not a supported
# identifier for process group for functional collectives.
group = _find_pg_by_ranks_and_tag(tag, ranks)
@ -1401,7 +1401,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int:
return -1
def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]:
def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]:
"""
Return the pg configuration of the given process group.
@ -1416,12 +1416,12 @@ def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]:
}
def _get_all_pg_configs() -> List[Dict[str, Any]]:
def _get_all_pg_configs() -> list[dict[str, Any]]:
"""
Return the pg configuration of all the process groups.
"""
config_info: List[Dict[str, Any]] = [
config_info: list[dict[str, Any]] = [
_get_pg_config(pg) for pg in _world.pg_map.keys()
]
return config_info
@ -2507,7 +2507,7 @@ class _IllegalWork(Work):
class _CoalescingManager:
def __init__(self) -> None:
self.works: List[Work] = []
self.works: list[Work] = []
def append(self, work: Work):
if work:
@ -2607,7 +2607,7 @@ def _coalescing_manager(
work.wait() # type: ignore[possibly-undefined]
def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
"""
Send or receive a batch of tensors asynchronously and return a list of requests.
@ -2651,7 +2651,7 @@ def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
group = p2p_op_list[0].group
device = p2p_op_list[0].tensor.device
def peer_kwarg(op: P2POp) -> Dict[str, int]:
def peer_kwarg(op: P2POp) -> dict[str, int]:
key = "group_dst" if op.op == isend else "group_src"
return {key: op.group_peer}
@ -3054,7 +3054,7 @@ def all_gather_object(object_list, obj, group=None):
@_exception_logger
def gather_object(
obj: Any,
object_gather_list: Optional[List[Any]] = None,
object_gather_list: Optional[list[Any]] = None,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
group_dst: Optional[int] = None,
@ -3178,7 +3178,7 @@ def gather_object(
@_exception_logger
def send_object_list(
object_list: List[Any],
object_list: list[Any],
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
@ -3277,7 +3277,7 @@ def send_object_list(
@_exception_logger
def recv_object_list(
object_list: List[Any],
object_list: list[Any],
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
@ -3379,7 +3379,7 @@ def recv_object_list(
@_exception_logger
def broadcast_object_list(
object_list: List[Any],
object_list: list[Any],
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None,
@ -3505,8 +3505,8 @@ def broadcast_object_list(
@_exception_logger
def scatter_object_list(
scatter_object_output_list: List[Any],
scatter_object_input_list: Optional[List[Any]] = None,
scatter_object_output_list: list[Any],
scatter_object_input_list: Optional[list[Any]] = None,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
group_src: Optional[int] = None,
@ -3933,7 +3933,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
@_exception_logger
def gather(
tensor: torch.Tensor,
gather_list: Optional[List[torch.Tensor]] = None,
gather_list: Optional[list[torch.Tensor]] = None,
dst: Optional[int] = None,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
@ -4013,7 +4013,7 @@ def gather(
@_exception_logger
def scatter(
tensor: torch.Tensor,
scatter_list: Optional[List[torch.Tensor]] = None,
scatter_list: Optional[list[torch.Tensor]] = None,
src: Optional[int] = None,
group: Optional[ProcessGroup] = None,
async_op: bool = False,
@ -4657,7 +4657,7 @@ def _create_process_group_wrapper(
# helper function for deterministically hashing a list of ranks to a unique
# string
def _hash_ranks_to_str(ranks: List[int]) -> str:
def _hash_ranks_to_str(ranks: list[int]) -> str:
rank_join: str = "_".join(map(str, ranks))
# In case there is already a PG with the same rank composition
unique_str = "_".join([rank_join, str(len(_world.pg_names))])
@ -4665,7 +4665,7 @@ def _hash_ranks_to_str(ranks: List[int]) -> str:
# Takes a list of ranks and computes an integer color
def _process_group_color(ranks: List[int]) -> int:
def _process_group_color(ranks: list[int]) -> int:
# Convert list to tuple to make it hashable
ranks = tuple(ranks)
hash_value = hash(ranks)
@ -5315,7 +5315,7 @@ def new_subgroups_by_enumeration(
return cur_subgroup, subgroups
def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]:
def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]:
if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
tag = f"user:{tag}"
@ -5331,7 +5331,7 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGro
def _find_or_create_pg_by_ranks_and_tag(
tag: str, ranks: List[int], stride: int
tag: str, ranks: list[int], stride: int
) -> ProcessGroup:
assert (
len(ranks) % stride == 0
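
batch_isend_irecv now takes and returns builtin-generic lists (list[P2POp] in, list[Work] out). A hedged sketch of the documented usage, assuming the default process group is initialized with at least two ranks and peer names a valid rank:

    import torch
    import torch.distributed as dist


    def exchange(t_send: torch.Tensor, t_recv: torch.Tensor, peer: int) -> None:
        ops: list[dist.P2POp] = [
            dist.P2POp(dist.isend, t_send, peer),
            dist.P2POp(dist.irecv, t_recv, peer),
        ]
        # Each returned Work handle must be waited on before t_recv is used.
        for work in dist.batch_isend_irecv(ops):
            work.wait()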

View File

@ -9,7 +9,7 @@
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Optional, Union
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
@ -79,13 +79,13 @@ class LaunchConfig:
role: str = "default_role"
rdzv_endpoint: str = ""
rdzv_backend: str = "etcd"
rdzv_configs: Dict[str, Any] = field(default_factory=dict)
rdzv_configs: dict[str, Any] = field(default_factory=dict)
rdzv_timeout: int = -1
max_restarts: int = 3
monitor_interval: float = 0.1
start_method: str = "spawn"
log_line_prefix_template: Optional[str] = None
metrics_cfg: Dict[str, str] = field(default_factory=dict)
metrics_cfg: dict[str, str] = field(default_factory=dict)
local_addr: Optional[str] = None
def __post_init__(self):
@ -140,7 +140,7 @@ class elastic_launch:
def _get_entrypoint_name(
entrypoint: Union[Callable, str, None], args: List[Any]
entrypoint: Union[Callable, str, None], args: list[Any]
) -> str:
"""Retrieve entrypoint name with the rule:
1. If entrypoint is a function, use ``entrypoint.__qualname__``.
@ -183,8 +183,8 @@ def _get_addr_and_port(
def launch_agent(
config: LaunchConfig,
entrypoint: Union[Callable, str, None],
args: List[Any],
) -> Dict[int, Any]:
args: list[Any],
) -> dict[int, Any]:
if not config.run_id:
run_id = str(uuid.uuid4().int)
logger.warning("config has no run_id, generated a random run_id: %s", run_id)

View File

@ -7,11 +7,10 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List
__all__: List[str] = []
__all__: list[str] = []
_log_handlers: Dict[str, logging.Handler] = {
_log_handlers: dict[str, logging.Handler] = {
"default": logging.NullHandler(),
}

View File

@ -4,20 +4,8 @@ import collections
import io
import sys
import types
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from collections.abc import Iterator, Mapping
from typing import Any, Callable, Optional, TypeVar, Union
import torch
import torch.distributed.rpc as rpc
@ -109,8 +97,8 @@ def _create_module_with_interface(
return rpc.RRef(module, module_interface_cls)
def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
ret: List[rpc.RRef[Parameter]] = [
def _param_rrefs(module_rref, recurse) -> list[rpc.RRef[Parameter]]:
ret: list[rpc.RRef[Parameter]] = [
rpc.RRef(param) for param in module_rref.local_value().parameters(recurse)
]
return ret
@ -129,9 +117,9 @@ class _RemoteModule(nn.Module):
def __init__(
self,
remote_device: str,
module_cls: Type[nn.Module],
args: Optional[Tuple] = None,
kwargs: Optional[Dict[str, Any]] = None,
module_cls: type[nn.Module],
args: Optional[tuple] = None,
kwargs: Optional[dict[str, Any]] = None,
_module_interface_cls: Any = None,
):
"""
@ -282,7 +270,7 @@ class _RemoteModule(nn.Module):
self._install_generated_methods()
self._check_attribute_picklability()
def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
def remote_parameters(self, recurse: bool = True) -> list[rpc.RRef[Parameter]]:
"""
Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters.
@ -371,8 +359,8 @@ class _RemoteModule(nn.Module):
hook: Union[
Callable[[T, tuple[Any, ...]], Optional[Any]],
Callable[
[T, tuple[Any, ...], Dict[str, Any]],
Optional[tuple[Any, Dict[str, Any]]],
[T, tuple[Any, ...], dict[str, Any]],
Optional[tuple[Any, dict[str, Any]]],
],
],
prepend: bool = False,
@ -384,7 +372,7 @@ class _RemoteModule(nn.Module):
self,
hook: Union[
Callable[[T, tuple[Any, ...], Any], Optional[Any]],
Callable[[T, tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]],
],
prepend: bool = False,
with_kwargs: bool = False,
@ -431,7 +419,7 @@ class _RemoteModule(nn.Module):
def named_modules(
self,
memo: Optional[Set[Module]] = None,
memo: Optional[set[Module]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
@ -681,9 +669,9 @@ class RemoteModule(_RemoteModule):
def __init__(
self,
remote_device: str,
module_cls: Type[nn.Module],
args: Optional[Tuple] = None,
kwargs: Optional[Dict[str, Any]] = None,
module_cls: type[nn.Module],
args: Optional[tuple] = None,
kwargs: Optional[dict[str, Any]] = None,
):
super().__init__(remote_device, module_cls, args, kwargs)

View File

@ -1,9 +1,10 @@
from typing import Any, Dict, Iterable, List, no_type_check, Type
from collections.abc import Iterable
from typing import Any, no_type_check
import torch
__all__: List[str] = []
__all__: list[str] = []
# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
@ -15,9 +16,9 @@ param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
@no_type_check
def _apply_optimizer_in_backward(
optimizer_class: Type[torch.optim.Optimizer],
optimizer_class: type[torch.optim.Optimizer],
params: Iterable[torch.nn.Parameter],
optimizer_kwargs: Dict[str, Any],
optimizer_kwargs: dict[str, Any],
register_hook: bool = True,
) -> None:
"""
@ -97,7 +98,7 @@ def _apply_optimizer_in_backward(
_apply_optimizer_in_backward_to_param(param)
def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]:
def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Optimizer]:
"""
Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these
optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called
@ -113,7 +114,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Opt
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
optims = _get_optimizers_in_backward(model)
"""
optims: List[torch.optim.Optimizer] = []
optims: list[torch.optim.Optimizer] = []
for param in module.parameters():
optims.extend(getattr(param, "_in_backward_optimizers", []))

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional Adadelta Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalAdadelta:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1.0,
rho: float = 0.9,
eps: float = 1e-6,
@ -51,9 +51,9 @@ class _FunctionalAdadelta:
# param group as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
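
The functional optimizers keep torch.jit.annotate so TorchScript can type an empty dict literal; the only change is that the annotation is now spelled with builtin generics, which this diff implies TorchScript accepts. A minimal sketch of that call in isolation (in eager mode, annotate simply returns its second argument):

    import torch

    # Per-parameter optimizer state, typed for TorchScript, initially empty.
    state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})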

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional Adagrad Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalAdagrad:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-2,
lr_decay: float = 0.0,
weight_decay: float = 0.0,
@ -53,7 +53,7 @@ class _FunctionalAdagrad:
self.foreach = foreach
self.fused = fused
self.maximize = maximize
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
@ -70,12 +70,12 @@ class _FunctionalAdagrad:
"step": torch.tensor(0.0),
}
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
state_sums = []
state_steps: List[Tensor] = []
state_steps: list[Tensor] = []
if len(params) != len(gradients):
raise ValueError(

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional Adam Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalAdam:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
@ -59,7 +59,7 @@ class _FunctionalAdam:
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
@ -78,7 +78,7 @@ class _FunctionalAdam:
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
state_steps: list[Tensor] = []
has_complex = torch.is_complex(param)
if grad is not None:
params_with_grad.append(param)
@ -128,14 +128,14 @@ class _FunctionalAdam:
found_inf=None,
)
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
state_steps: list[Tensor] = []
has_complex = False
if len(params) != len(gradients):

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional Adamax Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalAdamax:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
@ -55,7 +55,7 @@ class _FunctionalAdamax:
}
self.foreach = foreach
self.maximize = maximize
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
@ -64,13 +64,13 @@ class _FunctionalAdamax:
# param group as it's not a common use case.
self.param_group = {"params": params}
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
exp_avgs = []
exp_infs = []
state_steps: List[Tensor] = []
state_steps: list[Tensor] = []
if len(params) != len(gradients):
raise ValueError(

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional AdamW Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalAdamW:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-3,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
@ -59,7 +59,7 @@ class _FunctionalAdamW:
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
@ -74,7 +74,7 @@ class _FunctionalAdamW:
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
state_steps: list[Tensor] = []
has_complex = torch.is_complex(param)
if grad is not None:
params_with_grad.append(param)
@ -129,14 +129,14 @@ class _FunctionalAdamW:
has_complex=has_complex,
)
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps: List[Tensor] = []
state_steps: list[Tensor] = []
if len(params) != len(gradients):
raise ValueError(

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional RMSprop Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalRMSprop:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-2,
alpha: float = 0.99,
eps: float = 1e-8,
@ -55,9 +55,9 @@ class _FunctionalRMSprop:
# param group as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional Rprop Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalRprop:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-2,
etas: tuple[float, float] = (0.5, 1.2),
step_sizes: tuple[float, float] = (1e-6, 50),
@ -49,9 +49,9 @@ class _FunctionalRprop:
# param group as it's not a common use case.
self.param_group = {"params": params}
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.optim._functional as F
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
)
__all__: List[str] = []
__all__: list[str] = []
# Define a TorchScript compatible Functional SGD Optimizer
@ -25,7 +25,7 @@ __all__: List[str] = []
class _FunctionalSGD:
def __init__(
self,
params: List[Tensor],
params: list[Tensor],
lr: float = 1e-2,
momentum: float = 0.0,
dampening: float = 0.0,
@ -47,7 +47,7 @@ class _FunctionalSGD:
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
raise ValueError("optimizer got an empty parameter list")
@ -67,7 +67,7 @@ class _FunctionalSGD:
dampening = self.defaults["dampening"]
lr = self.defaults["lr"]
params = [param]
momentum_buffer_list: List[Optional[Tensor]] = []
momentum_buffer_list: list[Optional[Tensor]] = []
grads = []
has_sparse_grad = False
@ -106,11 +106,11 @@ class _FunctionalSGD:
if momentum_buffer is not None:
state["momentum_buffer"] = momentum_buffer
def step(self, gradients: List[Optional[Tensor]]):
def step(self, gradients: list[Optional[Tensor]]):
params = self.param_group["params"]
params_with_grad = []
grads = []
momentum_buffer_list: List[Optional[Tensor]] = []
momentum_buffer_list: list[Optional[Tensor]] = []
lr = self.defaults["lr"]
weight_decay = self.defaults["weight_decay"]
momentum = self.defaults["momentum"]

View File

@ -1,18 +1,9 @@
# mypy: allow-untyped-defs
import logging
import warnings
from collections.abc import Collection, Mapping
from copy import deepcopy
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Mapping,
Optional,
overload,
Union,
)
from typing import Any, Callable, Optional, overload, Union
import torch
import torch.nn as nn
@ -21,7 +12,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
__all__: List[str] = []
__all__: list[str] = []
logger = logging.getLogger(__name__)
@ -129,7 +120,7 @@ class _NamedOptimizer(optim.Optimizer):
)
param_group["params"] = params
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
"""
Return the ``state_dict`` of the optimizer.
@ -317,7 +308,7 @@ class _NamedOptimizer(optim.Optimizer):
# Calling ``step`` will load the initial state for optimizer states.
self.step(closure=None)
def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
def _pre_load_state_dict(self, state_dict) -> dict[str, Any]:
# TODO(chienchin): This API should be FSDP agnostic and should support
# general user hooks.
if isinstance(self.module, FSDP):
@ -326,7 +317,7 @@ class _NamedOptimizer(optim.Optimizer):
)
return state_dict
def _post_state_dict(self, state_dict) -> Dict[str, Any]:
def _post_state_dict(self, state_dict) -> dict[str, Any]:
# TODO(chienchin): This API should be FSDP agnostic and should support
# general user hooks.
if isinstance(self.module, FSDP):
@ -334,6 +325,6 @@ class _NamedOptimizer(optim.Optimizer):
return state_dict
def _gen_param_group_key(param_keys: List[str]) -> str:
def _gen_param_group_key(param_keys: list[str]) -> str:
"""Concatenate all param keys as a unique indentifier for one param group."""
return "/".join(sorted(param_keys))

View File

@ -3,7 +3,7 @@
import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional
from typing import Optional
import torch
import torch.distributed.autograd as dist_autograd
@ -52,7 +52,7 @@ class _ScriptLocalOptimizer(nn.Module):
def step(self, autograd_ctx_id: int):
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
# apply functional optimizer step with a list of gradients
grads: List[Optional[Tensor]] = [
grads: list[Optional[Tensor]] = [
all_local_grads[p] if p in all_local_grads else None
for p in self._local_params
]

View File

@ -1,5 +1,4 @@
# mypy: allow-untyped-defs
from typing import Type
from torch import optim
@ -46,7 +45,7 @@ def register_functional_optim(key, optim):
functional_optim_map[key] = optim
def as_functional_optim(optim_cls: Type, *args, **kwargs):
def as_functional_optim(optim_cls: type, *args, **kwargs):
try:
functional_cls = functional_optim_map[optim_cls]
except KeyError as e:
@ -57,7 +56,7 @@ def as_functional_optim(optim_cls: Type, *args, **kwargs):
return _create_functional_optim(functional_cls, *args, **kwargs)
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
def _create_functional_optim(functional_optim_cls: type, *args, **kwargs):
return functional_optim_cls(
[],
*args,

View File

@ -11,7 +11,7 @@ import inspect
import io
import logging
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.distributed as dist
@ -156,7 +156,7 @@ class _DDPBucketAssignment:
def __init__(
self,
bucket_index: int,
parameters: List[torch.Tensor],
parameters: list[torch.Tensor],
offset: int,
):
self.bucket_index = bucket_index
@ -239,20 +239,20 @@ class _OverlapInfo:
self.shard_buckets: bool = False
# Modified per bucket reconstruction
self.params_per_bucket: List[List[torch.Tensor]] = []
self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
self.offsets: Dict[int, int] = {}
self.params_per_bucket: list[list[torch.Tensor]] = []
self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)]
self.offsets: dict[int, int] = {}
# Group Ranks
self.assigned_ranks_per_bucket: List[Set[int]] = []
self.assigned_ranks_per_bucket: list[set[int]] = []
self.num_bucket_assignments: int = 0
self.total_size: Optional[int] = None
# Modified per iteration
self.broadcast_handles: List[Any] = []
self.bucket_indices_seen: List[int] = []
self.broadcast_handles: list[Any] = []
self.bucket_indices_seen: list[int] = []
# Used by `hook_with_zero_step()`
self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
self.bucket_index_to_future: dict[int, torch.futures.Future] = {}
self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {}
def wait_for_broadcasts(self) -> None:
r"""
@ -372,7 +372,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def __init__(
self,
params,
optimizer_class: Type[Optimizer],
optimizer_class: type[Optimizer],
process_group: Optional[Any] = None,
parameters_as_bucket_view: bool = False,
overlap_with_ddp: bool = False,
@ -395,15 +395,15 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
# `self.param_groups`
# Internal data structures (`_cache` indicates lazily evaluated)
self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
self._param_to_index_cache: Dict[torch.Tensor, int] = {}
self._partition_parameters_cache: List[List[Dict]] = []
self._index_to_param_cache: List[torch.Tensor] = []
self._device_to_params_per_rank_cache: Dict[
torch.device, List[List[torch.Tensor]]
self._param_to_rank_cache: dict[torch.Tensor, int] = {}
self._param_to_index_cache: dict[torch.Tensor, int] = {}
self._partition_parameters_cache: list[list[dict]] = []
self._index_to_param_cache: list[torch.Tensor] = []
self._device_to_params_per_rank_cache: dict[
torch.device, list[list[torch.Tensor]]
] = {}
self._bucket_assignments_per_rank_cache: List[
Dict[int, _DDPBucketAssignment]
self._bucket_assignments_per_rank_cache: list[
dict[int, _DDPBucketAssignment]
] = []
self._is_trainable_mask = self._get_is_trainable_mask()
@ -439,12 +439,12 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
# `self._buckets` is used if `parameters_as_bucket_view=True`, in
# which case parameter data is flattened into contiguous bucket tensors
self.parameters_as_bucket_view = parameters_as_bucket_view
self._buckets: List[List[torch.Tensor]] = []
self._buckets: list[list[torch.Tensor]] = []
self._build_param_buckets()
# Optional consolidated optimizer state, only populated if this rank
# is the target in `consolidate_state_dict()`
self._all_state_dicts: List[Dict[str, Any]] = []
self._all_state_dicts: list[dict[str, Any]] = []
self.initialized = True
@ -457,7 +457,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self._device_to_params_per_rank_cache.clear()
self._bucket_assignments_per_rank_cache.clear()
def add_param_group(self, param_group: Dict[str, Any]) -> None:
def add_param_group(self, param_group: dict[str, Any]) -> None:
r"""
Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
@ -586,7 +586,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _verify_params_per_rank(
self,
params_per_rank: List[List[torch.Tensor]],
params_per_rank: list[list[torch.Tensor]],
) -> None:
r"""
Verify ``params_per_rank`` for :meth:`_partition_parameters`.
@ -619,7 +619,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
)
def _partition_param_group(
self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]]
) -> None:
r"""
Partition the parameter group ``param_group`` according to ``params_per_rank``.
@ -641,8 +641,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _partition_parameters(
self,
params_per_rank: Optional[List[List[torch.Tensor]]] = None,
) -> List[List[Dict]]:
params_per_rank: Optional[list[list[torch.Tensor]]] = None,
) -> list[list[dict]]:
r"""
Partitions parameters across distributed data parallel ranks.
@ -674,7 +674,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self._partition_parameters_cache = [[] for _ in range(self.world_size)]
sizes = [0] * self.world_size
for param_group in self.param_groups:
param_group_params_per_rank: List[List] = [
param_group_params_per_rank: list[list] = [
[] for _ in range(self.world_size)
]
# Sort the parameters by size (largest first)
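The partitioning here is the greedy scheme the comments describe: walk parameters from largest to smallest and give each one to the rank with the least total size so far. A minimal standalone sketch (hypothetical helper, not the cached implementation in this class):

import torch

def greedy_partition(params: list[torch.Tensor], world_size: int) -> list[list[torch.Tensor]]:
    # Largest-first assignment to the rank with the smallest running total of elements.
    per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for p in sorted(params, key=lambda t: t.numel(), reverse=True):
        rank = sizes.index(min(sizes))
        per_rank[rank].append(p)
        sizes[rank] += p.numel()
    return per_rank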
@ -712,7 +712,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
return self._partition_parameters_cache
@property
def _param_to_rank(self) -> Dict[torch.Tensor, int]:
def _param_to_rank(self) -> dict[torch.Tensor, int]:
r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
if len(self._param_to_rank_cache) == 0:
for rank, param_groups in enumerate(self._partition_parameters()):
@ -722,7 +722,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
return self._param_to_rank_cache
@property
def _param_to_index(self) -> Dict[torch.Tensor, int]:
def _param_to_index(self) -> dict[torch.Tensor, int]:
r"""
:class:`dict` mapping parameters to their indices in the global optimizer state.
@ -737,7 +737,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
return self._param_to_index_cache
@property
def _index_to_param(self) -> List[torch.Tensor]:
def _index_to_param(self) -> list[torch.Tensor]:
r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
if len(self._index_to_param_cache) == 0:
self._index_to_param_cache = list(
@ -811,7 +811,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
@property
def _device_to_params_per_rank(
self,
) -> Dict[torch.device, List[List[torch.Tensor]]]:
) -> dict[torch.device, list[list[torch.Tensor]]]:
r"""
Return device parameters assigned per rank.
@ -854,8 +854,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _get_min_index(
self,
values: List[int],
disallowed_indices: Optional[Set[int]] = None,
values: list[int],
disallowed_indices: Optional[set[int]] = None,
) -> int:
r"""
Return ``values.index(min(values))``, except only uses one pass.
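A minimal sketch of the single-pass behavior this docstring describes (hypothetical standalone helper; the class method also honors ``disallowed_indices``, mirrored here):

from typing import Optional

def min_index_one_pass(values: list[int], disallowed_indices: Optional[set[int]] = None) -> int:
    # Track the smallest allowed value and its index in a single scan.
    disallowed_indices = disallowed_indices or set()
    best_index, best_value = -1, float("inf")
    for index, value in enumerate(values):
        if index in disallowed_indices:
            continue
        if value < best_value:
            best_index, best_value = index, value
    assert best_index >= 0, "all indices were disallowed"
    return best_index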
@ -881,10 +881,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _assign_bucket_subset_to_rank(
self,
bucket_index: int,
bucket_params: List[torch.Tensor],
bucket_params: list[torch.Tensor],
bucket_offset: int,
assigned_rank: int,
assigned_ranks_per_bucket: List[Set[int]],
assigned_ranks_per_bucket: list[set[int]],
) -> None:
r"""
Assign ``bucket_params`` to the rank with the least size assigned so far and collect relevant information.
@ -919,7 +919,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self._overlap_info.num_bucket_assignments += 1
@property
def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]:
r"""
Return DDP bucket parameters assigned per rank.
@ -1015,7 +1015,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _local_step(
self,
gradients: Optional[List[Optional[torch.Tensor]]] = None,
gradients: Optional[list[Optional[torch.Tensor]]] = None,
closure: Optional[Callable[[], float]] = None,
**kwargs: Any,
) -> Optional[float]:
@ -1148,7 +1148,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
r"""Return process group."""
return self.process_group
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
r"""
Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.
@ -1186,7 +1186,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self._sync_param_groups(state_dict["param_groups"], self.param_groups)
self._sync_param_groups(self.param_groups, self.optim.param_groups)
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
r"""
Return the last global optimizer state known to this rank.
@ -1252,8 +1252,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
@staticmethod
def _sync_param_groups(
src_param_groups: List[Dict[Any, Any]],
dst_param_groups: List[Dict[Any, Any]],
src_param_groups: list[dict[Any, Any]],
dst_param_groups: list[dict[Any, Any]],
) -> None:
r"""
Sync the attributes from the source parameter groups to the destination parameter groups.
@ -1381,7 +1381,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
def _verify_and_init_params(
self,
params: Any,
) -> Union[List[torch.Tensor], List[dict]]:
) -> Union[list[torch.Tensor], list[dict]]:
r"""
Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.
@ -1469,7 +1469,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
f"{other_typename}"
)
def _get_is_trainable_mask(self) -> List[bool]:
def _get_is_trainable_mask(self) -> list[bool]:
r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not."""
return list(map(_is_trainable, self._all_params))


@ -7,7 +7,7 @@ from collections import defaultdict
from enum import Enum
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.fx as fx
@ -116,7 +116,7 @@ def _insert_stage_symbolic_backward(
output_node: fx.Node,
):
# Collect metadata about tuple output values. TODO: move this to split_module or FX IR
tuples: Dict[fx.Node, Tuple] = {}
tuples: dict[fx.Node, tuple] = {}
for node in reversed(g.nodes):
if node.op == "call_function":
# In the forward pass, only emit placeholder, module calls, and
@ -155,7 +155,7 @@ def _insert_stage_symbolic_backward(
# We will only emit backward operations for nodes that can contribute
# to the specified loss value.
live_nodes = {loss_node: None}
val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
def assign_or_accumulate_grad(forward_node, grad_value):
if forward_node in val_to_grad and forward_node.op != "placeholder":
@ -349,7 +349,7 @@ class MultiUseParameterConfig(Enum):
REPLICATE = 2
MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]]
class DetachExecutor(fx.Interpreter):
@ -432,7 +432,7 @@ class _LinearNodeList:
def to_graph(self):
graph = fx.Graph()
ref_str_to_node: Dict[str, fx.Node] = {}
ref_str_to_node: dict[str, fx.Node] = {}
def ref_to_node(arg):
if isinstance(arg, _NodeReference):
@ -557,14 +557,14 @@ class Pipe(torch.nn.Module):
# Map parameter value to a dictionary that maps the user pipeline module
# to the local qualname within that module
params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}
params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {}
for m_qualname, mod in self.split_gm.named_children():
for p_qualname, param in mod.named_parameters():
params_to_users.setdefault(param, {})
params_to_users[param][m_qualname] = p_qualname
self.replicated_params: List[Dict[str, str]] = [
self.replicated_params: list[dict[str, str]] = [
use_mapping
for _, use_mapping in params_to_users.items()
if len(use_mapping) > 1
@ -645,7 +645,7 @@ class Pipe(torch.nn.Module):
@staticmethod
def _number_and_count_forward_stages(gm: fx.GraphModule):
num_stages = 0
found_idxs: Dict[int, None] = {}
found_idxs: dict[int, None] = {}
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith("submod_"):
node.meta["stage_idx"] = int(node.target[len("submod_") :])
@ -693,7 +693,7 @@ class Pipe(torch.nn.Module):
# Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
# parameters relies on the invariant that parameter accesses happen once. This is not necessarily
# the case (especially with custom tracers), so fix that up here.
get_attr_nodes: Dict[str, fx.Node] = {}
get_attr_nodes: dict[str, fx.Node] = {}
for node in traced.graph.nodes: # type: ignore[union-attr]
if node.op == "get_attr":
get_attr_nodes.setdefault(node.target, node)
@ -868,7 +868,7 @@ class Pipe(torch.nn.Module):
# [aliasing] store tensor id -> list of FQNs, built from state dict
# Also assign non-persistent buffers
id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
id_to_fqns: dict[int, set[str]] = defaultdict(set)
for fqn, tensor in mod.state_dict(keep_vars=True).items():
id_to_fqns[id(tensor)].add(fqn)
for fqn, tensor in mod.named_buffers():
@ -878,7 +878,7 @@ class Pipe(torch.nn.Module):
# need to move the `get_attr` nodes from the root of the graph to those
# hierarchies.
# [aliasing] use id -> fqn mapping to list out all valid FQNs
inputs_to_state: Dict[str, List[str]] = {}
inputs_to_state: dict[str, list[str]] = {}
for attr in attr_nodes:
_, tensor = _recursive_getattr_with_parent(mod, attr.target)
fqns = list(id_to_fqns[id(tensor)])
@ -890,7 +890,7 @@ class Pipe(torch.nn.Module):
# [aliasing] for each submodule split, assign attributes on FQNs that may be used.
# We determine this based on whether or not the FQN attribute parent exists.
# i.e. if the last submodule exists, assign the attribute.
added_attributes: Dict[str, List[str]] = defaultdict(list)
added_attributes: dict[str, list[str]] = defaultdict(list)
for fqn, tensor in mod.state_dict(keep_vars=True).items():
for name, submod in split.named_children():
if isinstance(submod, fx.GraphModule):
@ -999,7 +999,7 @@ class Pipe(torch.nn.Module):
def _trace_with_export(
mod: torch.nn.Module,
example_args: tuple[Any, ...],
example_kwargs: Optional[Dict[str, Any]] = None,
example_kwargs: Optional[dict[str, Any]] = None,
) -> ExportedProgram:
logger.info("Tracing model ...")
try:
@ -1023,7 +1023,7 @@ class Pipe(torch.nn.Module):
def from_tracing(
mod: torch.nn.Module,
example_args: tuple[Any, ...],
example_kwargs: Optional[Dict[str, Any]] = None,
example_kwargs: Optional[dict[str, Any]] = None,
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
):
# If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
@ -1173,7 +1173,7 @@ def _split_after_forward(self, *args, **kwargs):
pipe_split()
def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
# TODO: make this implementation out-of-place?
for qualname, split_type in spec.items():
atoms = qualname.split(".")
@ -1200,8 +1200,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
def pipeline(
module: torch.nn.Module,
mb_args: tuple[Any, ...],
mb_kwargs: Optional[Dict[str, Any]] = None,
split_spec: Optional[Dict[str, SplitPoint]] = None,
mb_kwargs: Optional[dict[str, Any]] = None,
split_spec: Optional[dict[str, SplitPoint]] = None,
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
"""


@ -2,7 +2,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import collections
import logging
from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Union
from collections.abc import Iterator
from typing import Any, Optional, Union
import torch
from torch.autograd.graph import GradientEdge, Node
@ -37,8 +38,8 @@ def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
def reverse_closure(
roots: List[Node], target_nodes: Set[Node], reverse_edges_dict
) -> tuple[Set[Node], Set[Node]]:
roots: list[Node], target_nodes: set[Node], reverse_edges_dict
) -> tuple[set[Node], set[Node]]:
"""
This function returns the reverse closure of the given roots,
i.e. the set of nodes that can be reached from the roots by following the
@ -46,9 +47,9 @@ def reverse_closure(
include in the closure.
"""
# Recurse until we reach a target node
closure: Set[Node] = set()
closure: set[Node] = set()
visited_target_nodes = set()
q: Deque[Node] = collections.deque()
q: collections.deque[Node] = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
@ -67,10 +68,10 @@ def reverse_closure(
return closure, visited_target_nodes
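A rough sketch of the breadth-first traversal this function performs over the reversed autograd edges (simplified; details such as whether target nodes are folded into the closure follow the actual code above, not this sketch):

import collections

def reverse_closure_sketch(roots, target_nodes, reverse_edges_dict):
    # BFS from the roots over reverse edges, stopping expansion at target nodes.
    closure = {node for node in roots if node is not None}
    visited_targets = set()
    queue = collections.deque(closure)
    while queue:
        node = queue.popleft()
        for prev in reverse_edges_dict.get(node, []):
            if prev in target_nodes:
                visited_targets.add(prev)
            elif prev not in closure:
                closure.add(prev)
                queue.append(prev)
    return closure, visited_targets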
def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]:
q: Deque[Node] = collections.deque()
root_seen: Set[Node] = set()
reverse_edges_dict: Dict[Node, List[Node]] = collections.defaultdict(list)
def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]:
q: collections.deque[Node] = collections.deque()
root_seen: set[Node] = set()
reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list)
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
@ -86,8 +87,8 @@ def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]:
def get_param_groups(
inputs: List[Node], params: List[Node], reverse_edges_dict
) -> List[Dict[str, Any]]:
inputs: list[Node], params: list[Node], reverse_edges_dict
) -> list[dict[str, Any]]:
"""
Given a list of inputs and a list of parameters, return a list of parameter
groups, where each group contains the parameters and the intermediates that
@ -103,12 +104,12 @@ def get_param_groups(
# reverse graph that starts with inputs, and goes up to the dOutput or the loss,
# but omits weights and any subgraphs connecting weights to this closure
inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict)
param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates
param_groups: dict[Node, dict[str, set]] = dict() # keyed on intermediates
for param in params:
closure, intersected = reverse_closure(
[param], inputs_closure, reverse_edges_dict
)
param_group: Dict[str, Set] = {
param_group: dict[str, set] = {
"params": {param},
"intermediates": intersected,
}
@ -124,8 +125,8 @@ def get_param_groups(
param_groups[input_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params: Set[Node] = set()
seen_ids: Set[int] = set()
union_params: set[Node] = set()
seen_ids: set[int] = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
@ -140,11 +141,11 @@ def get_param_groups(
def stage_backward_input(
stage_outputs_or_loss: List[torch.Tensor],
output_grads: Optional[List[torch.Tensor]],
input_values: List[torch.Tensor],
stage_outputs_or_loss: list[torch.Tensor],
output_grads: Optional[list[torch.Tensor]],
input_values: list[torch.Tensor],
weights: Iterator[Parameter],
) -> tuple[tuple[Optional[torch.Tensor], ...], List[Dict[str, Any]]]:
) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]:
"""
Compute the gradients for only the stage inputs with
respect to the stage outputs (if non-last stage) or loss (if last stage)
@ -155,13 +156,13 @@ def stage_backward_input(
Detaching the stage_outputs_or_loss at the end of this function is important as
it frees memory that the autograd graph would otherwise keep alive for later use (but does not actually need).
"""
stage_output_grad_fns: List[Node] = list(
stage_output_grad_fns: list[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss))
)
stage_input_grad_fns: List[Node] = list(
stage_input_grad_fns: list[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, input_values))
)
weight_grad_fns: List[Node] = list(
weight_grad_fns: list[Node] = list(
filter(None, map(_get_grad_fn_or_grad_acc, weights))
)
@ -222,11 +223,11 @@ def stage_backward_input(
def stage_backward_weight(
weights: Iterator[Parameter], param_groups: List[Dict[str, Any]], retain_graph=False
weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False
) -> tuple[Optional[torch.Tensor], ...]:
# map weights to param_group_weights
grad_acc_to_weight = {}
weight_grads: List[Optional[torch.Tensor]] = []
weight_grads: list[Optional[torch.Tensor]] = []
for index, weight in enumerate(weights):
grad_acc = _get_grad_fn_or_grad_acc(weight)
grad_acc_to_weight[grad_acc] = weight, index
@ -273,7 +274,7 @@ def stage_backward(
stage_output,
output_grads,
input_values,
outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used
outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used
) -> tuple[Optional[torch.Tensor], ...]:
"""
This is a helper function to:
@ -293,8 +294,8 @@ def stage_backward(
try:
# stage_output may be a composite datatype like dict. Extract all individual
# tensor values here
stage_output_tensors: List[torch.Tensor] = []
output_grad_tensors: List[Optional[torch.Tensor]] = []
stage_output_tensors: list[torch.Tensor] = []
output_grad_tensors: list[Optional[torch.Tensor]] = []
def extract_tensors_with_grads(
output_val,
@ -353,7 +354,7 @@ def stage_backward(
)
# Extract gradients wrt the input values
grad_inputs: List[Optional[torch.Tensor]] = []
grad_inputs: list[Optional[torch.Tensor]] = []
for val in input_values:
if isinstance(val, torch.Tensor):
grad_inputs.append(val.grad)


@ -1,6 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections import defaultdict
from typing import Dict, List, Set
import torch
from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
@ -9,10 +8,10 @@ from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule:
# Create an empty GraphModule to hold the outlined modules
new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
seen_nodes: Dict[str, torch.fx.Node] = {}
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: Dict[str, Set[str]] = defaultdict(set)
created_modules: Dict[str, torch.nn.Module] = {}
seen_nodes: dict[str, torch.fx.Node] = {}
seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
seen_attrs: dict[str, set[str]] = defaultdict(set)
created_modules: dict[str, torch.nn.Module] = {}
_ModuleFrame(
orig_graph,
tuple(orig_graph.nodes),


@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from dataclasses import dataclass
from typing import List, Union
from typing import Union
import torch
from torch import fx
@ -75,8 +75,8 @@ def validate_tensor_metadata(desc, expected, given):
def validate_tensors_metadata(
desc,
expected_tensors: Union[List[torch.Tensor], tuple[torch.Tensor, ...]],
actual_tensors: Union[List[torch.Tensor], tuple[torch.Tensor, ...]],
expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
):
if len(expected_tensors) != len(actual_tensors):
raise PipeliningShapeError(


@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
from torch.fx.node import map_aggregate
@ -92,7 +92,7 @@ class TensorChunkSpec:
@staticmethod
def from_dict(
chunk_dims: Dict[str, int],
chunk_dims: dict[str, int],
):
"""
A helper for creating a dictionary of `TensorChunkSpec` from a
@ -243,11 +243,11 @@ def _shard_dict_of_args(
def split_args_kwargs_into_chunks(
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
kwargs: Optional[dict[str, Any]],
chunks: int,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
) -> tuple[List[Tuple], List[Dict]]:
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
) -> tuple[list[tuple], list[dict]]:
"""
Given a sequence of args and kwargs, split them into a number of chunks
according to their respective chunking specs.
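Conceptually, splitting along the batch dimension looks like the hypothetical sketch below (the real utility additionally honors per-argument ``TensorChunkSpec``s and handles kwargs and non-tensor leaves):

import torch

def split_tensor_args(args: tuple, chunks: int, dim: int = 0) -> list[tuple]:
    # Split every tensor argument into `chunks` pieces along `dim`;
    # replicate non-tensor arguments across all microbatches.
    split_per_arg = [
        torch.tensor_split(a, chunks, dim=dim) if isinstance(a, torch.Tensor) else [a] * chunks
        for a in args
    ]
    return [tuple(parts[i] for parts in split_per_arg) for i in range(chunks)]

# e.g. split_tensor_args((torch.randn(8, 4), 3.0), chunks=4) yields 4 microbatch arg tuples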
@ -347,7 +347,7 @@ def split_args_kwargs_into_chunks(
def merge_chunks(
chunks: List[Any],
chunks: list[Any],
chunk_spec,
):
"""


@ -9,17 +9,7 @@ import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Set,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
@ -163,7 +153,7 @@ class _Action(NamedTuple):
def _format_pipeline_order(
pipeline_order: Dict[int, List[Optional[_Action]]],
pipeline_order: dict[int, list[Optional[_Action]]],
error_step_number: Optional[int] = None,
) -> str:
"""
@ -230,8 +220,8 @@ class _PipelineSchedule(ABC):
n_microbatches: int,
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
# From arguments
@ -256,7 +246,7 @@ class _PipelineSchedule(ABC):
self._has_backward = self._loss_fn is not None
# Holds the losses for each microbatch.
self._internal_losses: List[torch.Tensor] = []
self._internal_losses: list[torch.Tensor] = []
logger.info("Using %s", self.__class__.__name__)
def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
@ -302,10 +292,10 @@ class _PipelineSchedule(ABC):
@abstractmethod
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
@ -318,7 +308,7 @@ class _PipelineSchedule(ABC):
raise NotImplementedError
@abstractmethod
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
@ -333,10 +323,10 @@ class _PipelineSchedule(ABC):
def _check_inputs(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Pre-process/check inputs
@ -375,7 +365,7 @@ class _PipelineSchedule(ABC):
def _split_inputs(
self,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
"""
Splits a full-batch input into chunks (i.e. microbatches) and returns
@ -395,7 +385,7 @@ class _PipelineSchedule(ABC):
# Return a list of empty tuples/dicts with matching length as chunks
return [()] * self._n_microbatches, [{}] * self._n_microbatches
def _merge_outputs(self, output_chunks: List[Any]) -> Any:
def _merge_outputs(self, output_chunks: list[Any]) -> Any:
"""
Merge output chunks back to a batch state.
If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
@ -406,7 +396,7 @@ class _PipelineSchedule(ABC):
)
def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
def _batch_p2p(p2p_ops: list[dist.P2POp], desc: Optional[str] = None):
"""
Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
"""
@ -418,8 +408,8 @@ def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
def _sorted_batch_p2p(
p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> dict[int, dist.Work]:
"""
Sorts the list of P2P ops by the peer rank, and then calls
batch_isend_irecv. Returns a dictionary of works by peer rank. This function
@ -428,8 +418,8 @@ def _sorted_batch_p2p(
# Arrange p2p_ops by peer rank:
# int is the peer rank;
# List is the list of ops towards the peer
ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
work_by_peer: Dict[int, dist.Work] = {}
ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
work_by_peer: dict[int, dist.Work] = {}
if len(p2p_ops) == 0:
return work_by_peer
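A minimal sketch of the grouping this helper performs (hypothetical standalone version; the real one also threads through the descriptive logging label):

from collections import defaultdict
import torch.distributed as dist

def sorted_batch_p2p_sketch(p2p_ops: list[dist.P2POp]) -> dict[int, dist.Work]:
    # Group ops by peer rank, then launch one batched isend/irecv per peer.
    ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
    for op in p2p_ops:
        ops_by_peer[op.peer].append(op)
    work_by_peer: dict[int, dist.Work] = {}
    for peer, ops in sorted(ops_by_peer.items()):
        work_by_peer[peer] = dist.batch_isend_irecv(ops).pop()  # keep one request per peer
    return work_by_peer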
@ -461,8 +451,8 @@ class PipelineScheduleSingle(_PipelineSchedule):
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
# Init parent
@ -493,7 +483,7 @@ or equal to the number of stages ({self._num_stages})."
self._stage._prepare_backward_infra(self._n_microbatches)
self._stage_initialized = True
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
@ -535,10 +525,10 @@ class _ScheduleForwardOnly(PipelineScheduleSingle):
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule
@ -553,7 +543,7 @@ class _ScheduleForwardOnly(PipelineScheduleSingle):
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []
fwd_sends_to_wait: list[dist.Work] = []
# Run microbatches
for i in range(self._n_microbatches):
@ -586,10 +576,10 @@ class ScheduleGPipe(PipelineScheduleSingle):
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
@ -604,7 +594,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []
fwd_sends_to_wait: list[dist.Work] = []
# Run microbatches
for i in range(self._n_microbatches):
@ -636,7 +626,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
# Run backward
# Delay send waits
bwd_sends_to_wait: List[dist.Work] = []
bwd_sends_to_wait: list[dist.Work] = []
for i in range(self._n_microbatches):
with record_function(f"Backward {i}"):
ops = self._stage.get_bwd_recv_ops(i)
@ -677,10 +667,10 @@ class Schedule1F1B(PipelineScheduleSingle):
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
@ -820,9 +810,9 @@ class Schedule1F1B(PipelineScheduleSingle):
def _add_unshard_reshard(
compute_actions: List[Optional[_Action]],
compute_actions: list[Optional[_Action]],
max_active_stages: int = 3,
) -> List[_Action]:
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.
UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
@ -836,11 +826,11 @@ def _add_unshard_reshard(
"""
def next_stage_indices(
count: int, next_actions: List[Optional[_Action]]
) -> List[int]:
count: int, next_actions: list[Optional[_Action]]
) -> list[int]:
"""Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
seen: Set[int] = set()
ret: List[int] = []
seen: set[int] = set()
ret: list[int] = []
for a in next_actions:
if a is not None and a.stage_index not in seen:
@ -850,8 +840,8 @@ def _add_unshard_reshard(
break
return ret
active_stages: Set[int] = set()
fsdp_aware_actions: List[_Action] = []
active_stages: set[int] = set()
fsdp_aware_actions: list[_Action] = []
def _unshard(stage_index: int):
active_stages.add(stage_index)
@ -890,8 +880,8 @@ def _add_unshard_reshard(
def _merge_bw(
compute_actions: List[Optional[_Action]],
) -> List[_Action]:
compute_actions: list[Optional[_Action]],
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
(note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
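A simplified sketch of the merge rule (self-contained, with a namedtuple standing in for `_Action` and strings standing in for the computation-type enum; the real pass operates on the schedule's own types):

from collections import namedtuple
from typing import Optional

Action = namedtuple("Action", ["computation_type", "microbatch_index", "stage_index"])

def merge_backward_ops(actions: list[Optional[Action]]) -> list[Action]:
    # Fuse an input-grad op (I) immediately followed by the matching weight-grad op (W)
    # for the same (stage, microbatch) into one full backward (B).
    merged: list[Action] = []
    i = 0
    while i < len(actions):
        current = actions[i]
        nxt = actions[i + 1] if i + 1 < len(actions) else None
        if (
            current is not None and nxt is not None
            and current.computation_type == "I" and nxt.computation_type == "W"
            and current.stage_index == nxt.stage_index
            and current.microbatch_index == nxt.microbatch_index
        ):
            merged.append(current._replace(computation_type="B"))
            i += 2
        else:
            if current is not None:
                merged.append(current)
            i += 1
    return merged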
@ -925,12 +915,12 @@ def _merge_bw(
def _add_send_recv(
compute_actions: Dict[int, List[_Action]],
compute_actions: dict[int, list[_Action]],
stage_to_rank: Callable[[int], int],
num_stages: int,
) -> Dict[int, List[_Action]]:
comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}
prev_actions: Dict[int, Set[_Action]] = {rank: set() for rank in compute_actions}
) -> dict[int, list[_Action]]:
comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}
def _has_comms(action: _Action) -> bool:
if action.computation_type == F:
@ -954,7 +944,7 @@ def _add_send_recv(
return send, recv
def _ready_to_schedule(
action: Optional[_Action], prev_actions: Set[_Action]
action: Optional[_Action], prev_actions: set[_Action]
) -> bool:
"""We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
This helps ensure a sane (non-hanging) ordering of sends and recvs.
@ -1030,7 +1020,7 @@ def _add_send_recv(
def _validate_schedule(
actions: Dict[int, List[Optional[_Action]]],
actions: dict[int, list[Optional[_Action]]],
pp_group_size: int,
num_stages: int,
num_microbatches: int,
@ -1043,7 +1033,7 @@ def _validate_schedule(
# We will count all the actions per stage and ensure they happen in a valid order
# (e.g. F before (B, I) before W for a given microbatch)
stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
stage_actions: dict[int, dict[_ComputationType, set]] = {
stage_id: {
F: set(),
B: set(),
@ -1108,13 +1098,13 @@ class PipelineScheduleMulti(_PipelineSchedule):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
stage_index_to_group_rank: Optional[Dict[int, int]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
stage_index_to_group_rank: Optional[dict[int, int]] = None,
use_full_backward: Optional[bool] = None,
scale_grads: bool = True,
):
@ -1148,7 +1138,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self._should_compute_loss = lambda stage: stage.is_last and has_loss
# This will be set during init of derived schedules
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
if use_full_backward is not None:
logger.warning(
@ -1199,7 +1189,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
self._n_microbatches,
)
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
@ -1235,10 +1225,10 @@ class PipelineScheduleMulti(_PipelineSchedule):
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Operate on the microbatches for looped schedules (multiple stages on each rank).
@ -1253,14 +1243,14 @@ class PipelineScheduleMulti(_PipelineSchedule):
# Based on the plan in Step 1 created in __init__:
# 2. Perform communication based on the pipeline_order
stage_index_to_stage: Dict[int, _PipelineStageBase] = {
stage_index_to_stage: dict[int, _PipelineStageBase] = {
stage.stage_index: stage for stage in self._stages
}
# determine prev_rank and next_rank based on which ranks are next to
# the stages in the pipeline_order
all_prev_ranks: Set[int] = set()
all_next_ranks: Set[int] = set()
all_prev_ranks: set[int] = set()
all_next_ranks: set[int] = set()
for stage_index in stage_index_to_stage.keys():
# TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
if stage_index > 0:
@ -1271,7 +1261,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
backward_counter: Counter[int] = Counter()
for time_step, action in enumerate(self.pipeline_order[self.rank]):
try:
ops: List[dist.P2POp] = []
ops: list[dist.P2POp] = []
if action is not None:
computation_type = action.computation_type
mb_index = action.microbatch_index
@ -1432,7 +1422,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
def _load_actions(
self,
actions: Dict[int, List[Optional[_Action]]],
actions: dict[int, list[Optional[_Action]]],
format: str = "compute_only",
):
"""
@ -1442,7 +1432,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
assert (
self.stage_index_to_group_rank is not None
), "stage_index_to_group_rank is required for PipelineScheduleRuntime"
self.pipeline_order_with_comms: Dict[int, List[_Action]] = {}
self.pipeline_order_with_comms: dict[int, list[_Action]] = {}
if format == "compute_comms":
for rank in actions:
self.pipeline_order_with_comms[rank] = []
@ -1507,10 +1497,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Operate on the microbatches for looped schedules (multiple stages on each rank).
@ -1524,7 +1514,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
# Based on the plan in Step 1 created in __init__:
# 2. Perform communication based on the pipeline_order
stage_index_to_stage: Dict[int, _PipelineStageBase] = {
stage_index_to_stage: dict[int, _PipelineStageBase] = {
stage.stage_index: stage for stage in self._stages
}
@ -1533,14 +1523,14 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
), "Must call _load_actions() before calling _step_microbatches()"
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
bwd_recv_ops: Dict[tuple[int, int], Work] = {}
fwd_recv_ops: Dict[tuple[int, int], Work] = {}
bwd_recv_ops: dict[tuple[int, int], Work] = {}
fwd_recv_ops: dict[tuple[int, int], Work] = {}
# send ops should be waited on before step() exits, mainly for hygiene
send_ops: List[Work] = []
send_ops: list[Work] = []
# we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
unshard_ops: Dict[int, UnshardHandle] = {}
unshard_ops: dict[int, UnshardHandle] = {}
unsharded_stages = set()
def _assert_unsharded(stage_idx: int):
@ -1751,10 +1741,10 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Union[Callable, _Loss]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
super().__init__(
@ -1768,7 +1758,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
# ========================================================================
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
@ -1782,7 +1772,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
# Store the list of operations used for that rank
# Pre-padding, rank starts with no-ops based on the warmup.
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
for stage_index in stage_indices:
rank_ops.extend(
@ -1816,13 +1806,13 @@ def _get_1f1b_rank_ops(
enable_zero_bubble=False,
):
# All stages start with handling microbatch 0
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
fwd_stage_mb_index: dict[int, int] = defaultdict(int)
bwd_stage_mb_index: dict[int, int] = defaultdict(int)
weight_stage_mb_index: dict[int, int] = defaultdict(int)
# Store the list of operations used for that rank
# Pre-padding, rank starts with no-ops based on the warmup.
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
# Formula:
@ -1960,12 +1950,12 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
self.pp_group_size = stages[0].group_size
@ -1991,12 +1981,12 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def get_rank_warmup_ops(rank):
# Warms up operations for last stage
warmups_ops_last_stage = (
@ -2069,12 +2059,12 @@ class ScheduleInterleavedZeroBubble(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
# TODO: we don't support Zero Bubble with torch.compile so we
@ -2109,7 +2099,7 @@ stage modules that have used torch.compile"
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
@ -2121,7 +2111,7 @@ stage modules that have used torch.compile"
self.n_local_stages * self.pp_group_size,
)
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
def get_rank_warmup_ops(rank):
# Warms up operations for last stage
warmups_ops_last_stage = (
@ -2198,10 +2188,10 @@ stage modules that have used torch.compile"
return (stage + 1, op, microbatch) not in seen_ops
return False
seen_ops: Set[tuple[int, _ComputationType, int]] = set()
result: Dict[int, List[Optional[_Action]]] = {}
next_pointer: Dict[int, int] = {}
bubbles_added: Dict[int, int] = {}
seen_ops: set[tuple[int, _ComputationType, int]] = set()
result: dict[int, list[Optional[_Action]]] = {}
next_pointer: dict[int, int] = {}
bubbles_added: dict[int, int] = {}
total_bubbles_added = 0
for rank in range(self.pp_group_size):
@ -2212,7 +2202,7 @@ stage modules that have used torch.compile"
while True:
should_stop = True
temp_seen_ops: Set[tuple[int, _ComputationType, int]] = set()
temp_seen_ops: set[tuple[int, _ComputationType, int]] = set()
for rank in range(self.pp_group_size):
timestamp = next_pointer[rank]
@ -2270,13 +2260,13 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti):
def __init__(
self,
stages: List[_PipelineStageBase],
stages: list[_PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
stage_index_to_group_rank: Optional[Dict[int, int]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
stage_index_to_group_rank: Optional[dict[int, int]] = None,
scale_grads: bool = True,
):
self.pp_group_size = stages[0].group_size
@ -2303,16 +2293,16 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti):
# 1. Create the pipeline_order (all ranks do this calculation)
# This will be used to keep track of the current state of the entire pipeline
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
for rank in range(self.pp_group_size):
rank_ops = self._calculate_single_rank_operations(rank)
self.pipeline_order[rank] = rank_ops
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
# max(2 * self.pp_group_size - 1, ...) ensures the number of microbatches is at least
# as large as the number of microbatches needed to fully utilize the pipeline
n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
# Forward and backward action counts for stage chunk 0 and chunk 1
f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
@ -2469,11 +2459,11 @@ def _simulate_comms_compute(
rank: [a for a in pipeline_order[rank] if a is not None]
for rank in sorted(pipeline_order)
}
_schedule: Dict[int, List[_Action | None]] = {
_schedule: dict[int, list[_Action | None]] = {
rank: [] for rank in sorted(pipeline_order)
}
_prev_ops_rank: Dict[int, Set[_Action]] = {rank: set() for rank in _schedule}
_prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}
def add_to_schedule(rank: int, action: Optional[_Action]):
_schedule[rank].append(action)


@ -3,7 +3,7 @@
import logging
import operator
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.distributed as dist
@ -160,10 +160,10 @@ class _PipelineStageBase(ABC):
self.dw_builder = dw_builder
# backward state
self.backward_state: Dict[int, tuple[Any, ...]] = {}
self.backward_state: dict[int, tuple[Any, ...]] = {}
# store dw_runner per microbatch_id
self.dw_runner: Dict[int, Callable[..., None]] = {}
self.dw_runner: dict[int, Callable[..., None]] = {}
# `group_rank` is rank in process group `group`.
self.group_rank = dist.get_rank(self.group)
@ -176,11 +176,11 @@ class _PipelineStageBase(ABC):
# Run time states
self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None
# map microbatch ID to list of forward tensor args
self.fwd_cache: Dict[int, tuple[Any, List[torch.Tensor]]] = {}
self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {}
# map microbatch ID to list of backward grad tensor args
self.bwd_cache: Dict[int, tuple[Optional[torch.Tensor], ...]] = {}
self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {}
# Caching chunk outputs for final output merge or reduction
self.output_chunks: List[Any] = []
self.output_chunks: list[Any] = []
# Initialize has_backward to false; this will be set to true if loss
# function is passed to pipeline schedule
@ -189,16 +189,16 @@ class _PipelineStageBase(ABC):
self.log_prefix = f"[Stage {self.stage_index}]"
# Forward infra
self.args_recv_info: Dict[int, tuple[InputInfo, ...]] = {}
self.act_send_info: Dict[int, List] = {}
self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
self.act_send_info: dict[int, list] = {}
# Backward infra will created lazily
self.grad_recv_info: Dict = {}
self.grad_send_info: Optional[List] = None
self.grad_recv_info: dict = {}
self.grad_send_info: Optional[list] = None
# To be populated later by the Schedule
self.chunks: Optional[int] = None
self.stage_index_to_group_rank: Dict[int, int] = {
self.stage_index_to_group_rank: dict[int, int] = {
i: i % self.group_size for i in range(self.num_stages)
}
@ -258,12 +258,12 @@ class _PipelineStageBase(ABC):
def _create_grad_send_info(
self,
args_recv_info: Tuple,
) -> List[Optional[int]]:
args_recv_info: tuple,
) -> list[Optional[int]]:
"""
Create a list of stage indices to send gradients to.
"""
grad_send_info: List[Optional[int]] = []
grad_send_info: list[Optional[int]] = []
def map_recv_to_send(a):
# Note: we send gradients back to previous stage as long as in
@ -286,7 +286,7 @@ class _PipelineStageBase(ABC):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> tuple[Any, ...]:
raise NotImplementedError
@ -303,19 +303,19 @@ class _PipelineStageBase(ABC):
@abstractmethod
def _create_grad_recv_info(
self,
act_send_info: Dict,
act_send_info: dict,
) -> tuple[_RecvInfo, ...]:
raise NotImplementedError
def _get_recv_ops(
self,
recv_infos: tuple[InputInfo, ...],
) -> List[dist.P2POp]:
) -> list[dist.P2POp]:
"""
Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
Returns a list of ops that correspond to the recv infos.
"""
ops: List[dist.P2POp] = []
ops: list[dist.P2POp] = []
for info in recv_infos:
if not isinstance(info, _RecvInfo):
continue
@ -410,7 +410,7 @@ class _PipelineStageBase(ABC):
), f"Expected a recv info, got {type(info)}"
info.buffer = tensor
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
"""
Returns a list of ops that are needed to receive the input arguments
for this stage.
@ -419,7 +419,7 @@ class _PipelineStageBase(ABC):
return self._get_recv_ops(recv_infos)
def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]:
"""
Returns a list of ops that are needed to receive the gradients
for this stage.
@ -430,7 +430,7 @@ class _PipelineStageBase(ABC):
recv_infos = self.grad_recv_info[bwd_chunk_id]
return self._get_recv_ops(recv_infos)
def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
"""
Get the activation send ops for current stage's forward.
"""
@ -439,7 +439,7 @@ class _PipelineStageBase(ABC):
# `act_send_info`
output_tuple = output if type(output) is tuple else (output,)
ops: List[dist.P2POp] = []
ops: list[dist.P2POp] = []
for idx, out in enumerate(output_tuple):
dst_stages = self.act_send_info[idx]
@ -462,7 +462,7 @@ class _PipelineStageBase(ABC):
return ops
def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]:
"""
Get the gradient send ops for current stage's backward.
"""
@ -479,7 +479,7 @@ class _PipelineStageBase(ABC):
# `grad_send_info` is a mirror of `args_recv_info`
self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0])
ops: List[dist.P2POp] = []
ops: list[dist.P2POp] = []
grads_input = self.bwd_cache.pop(bwd_chunk_id)
for grad, grad_recv_stage in zip(grads_input, self.grad_send_info):
if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
@ -593,9 +593,9 @@ class _PipelineStageBase(ABC):
def backward_maybe_with_nosync(
self,
backward_type,
bwd_kwargs: Dict,
bwd_kwargs: dict,
last_backward: bool = False,
) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[List[Dict[str, Any]]]]:
) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]:
"""
Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the
other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
@ -607,7 +607,7 @@ class _PipelineStageBase(ABC):
backward_type,
) -> Callable[
[],
tuple[tuple[Optional[torch.Tensor], ...], Optional[List[Dict[str, Any]]]],
tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]],
]:
if backward_type == "full":
return lambda: (
@ -686,7 +686,7 @@ class _PipelineStageBase(ABC):
self,
fwd_chunk_id: int,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
"""
Perform forward pass on the stage with one microbatch.
@ -817,7 +817,7 @@ class _PipelineStageBase(ABC):
"full", bwd_kwargs, last_backward=last_backward
)
else:
param_groups: List[Dict[str, Any]] | None = None
param_groups: list[dict[str, Any]] | None = None
# Skip the backward for the first stage since we will perform the weight update with
# autograd.backward in backward_weight_one_chunk
if not self.is_first:
@ -986,7 +986,7 @@ class _PipelineStage(_PipelineStageBase):
)
# Create mapping from stage name to stage index
self.submod_to_stage_index: Dict[str, int] = {}
self.submod_to_stage_index: dict[str, int] = {}
for i, node in enumerate(submod_nodes):
self.submod_to_stage_index.setdefault(node.name, i)
@ -1010,7 +1010,7 @@ class _PipelineStage(_PipelineStageBase):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> tuple[Any, ...]:
"""
Create send/recv infrastructures for activations (during forward)
@ -1084,7 +1084,7 @@ class _PipelineStage(_PipelineStageBase):
buffer,
)
args_recv_info: List[InputInfo] = []
args_recv_info: list[InputInfo] = []
# Filter out placeholder nodes from `self.submod` (a GraphModule)
placeholders = filter( # type: ignore[var-annotated]
lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr]
@ -1134,7 +1134,7 @@ class _PipelineStage(_PipelineStageBase):
be consumed by multiple stages.
"""
# Output index: List of receiver ranks
act_send_info: Dict[int, List] = {}
act_send_info: dict[int, list] = {}
out_idx = 0
for user in self.node.users:
@ -1171,13 +1171,13 @@ class _PipelineStage(_PipelineStageBase):
def _create_grad_recv_info(
self,
act_send_info: Dict,
act_send_info: dict,
) -> tuple[_RecvInfo, ...]:
"""
Create a tuple of `_RecvInfo` for gradients.
"""
# Dict[output_index, _RecvInfo]
grad_recv_info: Dict[int, _RecvInfo] = {}
grad_recv_info: dict[int, _RecvInfo] = {}
output_node = self._get_output_node()
# The output node may take multiple args, meaning the submod has multiple output values.
@ -1275,7 +1275,7 @@ class PipelineStage(_PipelineStageBase):
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
):
super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
self.inputs: Optional[List[torch.Tensor]] = None
self.inputs: Optional[list[torch.Tensor]] = None
self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None
# Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it
# might break existing users.
@ -1313,7 +1313,7 @@ class PipelineStage(_PipelineStageBase):
)
# these are the buffers used in backwards send/recv, they are allocated later
self.outputs_grad: List[torch.Tensor] = []
self.outputs_grad: list[torch.Tensor] = []
def stage_global_rank(peer_rank):
return (
@ -1342,7 +1342,7 @@ class PipelineStage(_PipelineStageBase):
def _shape_inference(
self,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
if kwargs is None:
kwargs = {}
@ -1443,7 +1443,7 @@ class PipelineStage(_PipelineStageBase):
self,
num_microbatches: int,
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> tuple[Any, ...]:
# TODO move self.device to an argument from step API (from its input tensors)?
assert num_microbatches is not None, "TODO fix num_microbatches"
@ -1481,7 +1481,7 @@ class PipelineStage(_PipelineStageBase):
# Send info during forward for each activation
# only need the rank that is being sent to
self.act_send_info: Dict[int, List] = {}
self.act_send_info: dict[int, list] = {}
for idx in range(len(self.get_outputs_meta())):
# We assume we always send to stage + 1
@ -1494,7 +1494,7 @@ class PipelineStage(_PipelineStageBase):
def _create_grad_recv_info(
self,
act_send_info: Dict,
act_send_info: dict,
) -> tuple[_RecvInfo, ...]:
grad_recv_info: tuple[_RecvInfo, ...] = ()
if not self.is_last:


@ -9,15 +9,16 @@ except ImportError as e:
import numbers
import os
import sys
from collections.abc import Iterator
from datetime import timedelta
from typing import Callable, Dict, Iterator, Optional
from typing import Callable, Optional
from torch.distributed import FileStore, Store, TCPStore
from .constants import default_pg_timeout
_rendezvous_handlers: Dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {}
_rendezvous_handlers: dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {}
__all__ = ["register_rendezvous_handler", "rendezvous"]
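Handlers registered here must be callables yielding ``(store, rank, world_size)`` triples, per the type above. A hedged sketch of registering a custom scheme (the "demo" scheme name and URL layout are made up for illustration):

from urllib.parse import urlparse
from torch.distributed import TCPStore
from torch.distributed.rendezvous import register_rendezvous_handler

def _demo_rendezvous_handler(url: str, **kwargs):
    # Expects a URL such as "demo://localhost:29500?rank=0&world_size=2" (hypothetical layout).
    parsed = urlparse(url)
    query = dict(pair.split("=") for pair in parsed.query.split("&") if pair)
    rank, world_size = int(query["rank"]), int(query["world_size"])
    store = TCPStore(parsed.hostname, parsed.port, world_size, is_master=(rank == 0))
    yield store, rank, world_size

register_rendezvous_handler("demo", _demo_rendezvous_handler)  # "demo" is a hypothetical scheme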
@ -54,14 +55,14 @@ def register_rendezvous_handler(scheme, handler):
# Query will have format "rank=0&world_size=1" and is
# converted into {"rank": 0, "world_size": 1}
def _query_to_dict(query: str) -> Dict[str, str]:
def _query_to_dict(query: str) -> dict[str, str]:
return {
pair[0]: pair[1]
for pair in (pair.split("=") for pair in filter(None, query.split("&")))
}
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool:
# libuv is the default backend for TCPStore. To enable the non-libuv backend,
# user can explicitly specify ``use_libuv=0`` in the URL parameter.
return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"
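For example, given the two helpers above (values are kept as strings; only the comparison with "1" turns them into a boolean):

query_dict = _query_to_dict("rank=0&world_size=2&use_libuv=0")
assert query_dict == {"rank": "0", "world_size": "2", "use_libuv": "0"}  # values stay strings
assert _get_use_libuv_from_query_dict(query_dict) is False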


@ -3,8 +3,9 @@ import logging
import os
import threading
import warnings
from collections.abc import Generator
from datetime import timedelta
from typing import Generator, Tuple
from typing import Tuple
from urllib.parse import urlparse
import torch


@ -7,7 +7,7 @@ import functools
import inspect
import logging
import threading
from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar
from typing import Any, Generic, TYPE_CHECKING, TypeVar
import torch
from torch._C._distributed_rpc import (
@ -115,9 +115,9 @@ class AllGatherStates:
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer.
_ALL_WORKER_NAMES: Set[Any] = set()
_ALL_WORKER_NAMES: set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: Dict[str, int] = {}
_all_gather_sequence_id: dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
AllGatherStates
)


@ -3,7 +3,7 @@
import collections
import enum
from typing import cast, Dict, List, Set
from typing import cast
import torch
import torch.distributed as dist
@ -163,8 +163,8 @@ def _tensorpipe_validate_devices(devices, device_count):
def _tensorpipe_exchange_and_check_all_device_maps(
my_name, my_device_count, my_device_maps, my_devices, group
):
gathered: List[
tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]]
gathered: list[
tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]]
] = [("", 0, {}, []) for _ in range(group.size())]
dist.all_gather_object(
gathered, (my_name, my_device_count, my_device_maps, my_devices), group
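The exchange above is the standard `all_gather_object` pattern: pre-size a list with one placeholder per rank, then let every rank contribute its tuple. A minimal standalone sketch, assuming a process group has already been initialized elsewhere:
import torch.distributed as dist

# (name, device_count, device_maps, devices), mirroring the tuple gathered above
payload = ("worker0", 2, {}, [])
gathered = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(gathered, payload)  # every rank now holds all payloads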
@ -253,7 +253,7 @@ def _validate_device_maps(
def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
if not my_devices:
devices_set: Set[torch.device] = set()
devices_set: set[torch.device] = set()
for map_ in my_device_maps.values():
devices_set.update(map_.keys())
for map_ in reverse_device_maps.values():
@ -265,7 +265,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
def _create_reverse_mapping(my_name, all_names, all_device_maps):
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {}
for node in all_names:
if my_name in all_device_maps[node]:
reverse_device_maps[node] = {


@ -1,5 +1,4 @@
from datetime import timedelta
from typing import List
from torch._C._distributed_rpc import (
_DEFAULT_INIT_METHOD,
@ -22,4 +21,4 @@ DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1)
# Value indicating that timeout is not set for RPC call, and the default should be used.
UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
__all__: List[str] = []
__all__: list[str] = []


@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
@ -23,10 +23,10 @@ def _to_device(device: DeviceType) -> torch.device:
def _to_device_map(
device_map: Dict[DeviceType, DeviceType]
) -> Dict[torch.device, torch.device]:
full_device_map: Dict[torch.device, torch.device] = {}
reverse_map: Dict[torch.device, torch.device] = {}
device_map: dict[DeviceType, DeviceType]
) -> dict[torch.device, torch.device]:
full_device_map: dict[torch.device, torch.device] = {}
reverse_map: dict[torch.device, torch.device] = {}
for k, v in device_map.items():
k, v = torch.device(k), torch.device(v)
if v in reverse_map:
@ -39,7 +39,7 @@ def _to_device_map(
return full_device_map
def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
def _to_device_list(devices: list[DeviceType]) -> list[torch.device]:
return list(map(_to_device, devices))
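Based on the code shown, both helpers normalize user-supplied specifiers into `torch.device` objects, with `_to_device_map` also tracking a reverse map to catch colliding targets. A hedged sketch using string specifiers:
full_map = _to_device_map({"cuda:0": "cuda:1"})
# {torch.device("cuda:0"): torch.device("cuda:1")}
devices = _to_device_list(["cuda:0", "cuda:1"])
# [torch.device("cuda:0"), torch.device("cuda:1")]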
@ -83,10 +83,10 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
devices: Optional[List[DeviceType]] = None,
_transports: Optional[List] = None,
_channels: Optional[List] = None,
device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None,
devices: Optional[list[DeviceType]] = None,
_transports: Optional[list] = None,
_channels: Optional[list] = None,
):
full_device_maps = (
{}
@ -104,7 +104,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
full_device_list,
)
def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]):
r"""
Set device mapping between each RPC caller and callee pair. This
function can be called multiple times to incrementally add
@ -162,7 +162,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
super()._set_device_map(to, full_device_map)
def set_devices(self, devices: List[DeviceType]):
def set_devices(self, devices: list[DeviceType]):
r"""
Set local devices used by the TensorPipe RPC agent. When processing
CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
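A usage sketch for the two setters above; the worker name and device indices are illustrative:
from torch.distributed.rpc import TensorPipeRpcBackendOptions

opts = TensorPipeRpcBackendOptions(num_worker_threads=8)
# Tensors sent from this worker's cuda:0 should land on "worker1"'s cuda:1.
opts.set_device_map("worker1", {0: 1})
# Incremental calls add further entries for the same callee.
opts.set_device_map("worker1", {"cuda:2": "cuda:3"})
opts.set_devices(["cuda:0", "cuda:2"])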


@ -2,7 +2,6 @@
# mypy: allow-untyped-defs
import itertools
from typing import List
import torch
from torch.autograd.profiler_legacy import profile
@ -13,7 +12,7 @@ from . import (
)
__all__: List[str] = []
__all__: list[str] = []
class _server_process_global_profile(profile):


@ -398,7 +398,7 @@ import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Type, Union
from typing import Callable, Optional, Union
import torch
from torch.distributed.argparse_util import check_env, env
@ -736,7 +736,7 @@ def get_use_env(args) -> bool:
return args.use_env
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
"""
Attempts to load the `torchrun.logs_spec` entrypoint with the key given by the `logs_specs_name` param.
Provides a plugin mechanism for supplying custom implementations of LogsSpecs.
@ -770,7 +770,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
return logs_specs_cls
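The docstring above describes a standard entry-point lookup; a package could register a custom LogsSpecs roughly as follows. The group name `torchrun.logs_specs` and the key `my_specs` are assumptions (the exact group should match whatever `_get_logs_specs_class` queries), and the package and class names are hypothetical:
# setup.py of a hypothetical plugin package
from setuptools import setup

setup(
    name="my-logs-plugin",
    entry_points={
        "torchrun.logs_specs": [
            "my_specs = my_logs_plugin:MyLogsSpecs",
        ],
    },
)
With such a plugin installed, passing `my_specs` as the logs-specs name would presumably resolve to `MyLogsSpecs` through this lookup.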
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str]]:
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
assert 0 < min_nodes <= max_nodes
@ -810,7 +810,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
rdzv_endpoint = get_rdzv_endpoint(args)
ranks: Optional[Set[int]] = None
ranks: Optional[set[int]] = None
if args.local_ranks_filter:
try:
ranks = set(map(int, args.local_ranks_filter.split(",")))
@ -820,7 +820,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
"--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
) from e
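A worked example of the parsing that the error message above describes:
local_ranks_filter = "0,1,2"
ranks: set[int] = set(map(int, local_ranks_filter.split(",")))
assert ranks == {0, 1, 2}
# A value such as "0,a" raises ValueError, which triggers the error above.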
logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
logs_specs = logs_specs_cls(
log_dir=args.log_dir,
redirects=Std.from_str(args.redirects),


@ -1,18 +1,9 @@
# mypy: allow-untyped-defs
import dataclasses
import traceback
from typing import (
Any,
Callable,
Container,
Dict,
List,
Optional,
OrderedDict,
overload,
Set,
TypeVar,
)
from collections import OrderedDict
from collections.abc import Container
from typing import Any, Callable, Optional, overload, TypeVar
import torch
import torch.distributed as dist
@ -43,8 +34,8 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str,
kwarg keys. The second tuple element gives the kwarg keys.
The second tuple element's length is at most the first tuple element's length.
"""
kwarg_keys: List[str] = []
flat_args: List[Any] = list(args)
kwarg_keys: list[str] = []
flat_args: list[Any] = list(args)
for k, v in kwargs.items():
kwarg_keys.append(k)
flat_args.append(v)
@ -75,7 +66,7 @@ def _cast_forward_inputs(
def _unpack_kwargs(
flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""See _pack_kwargs."""
assert len(kwarg_keys) <= len(
flat_args
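A worked round trip for the two helpers, following the docstring's description (flattened args first, kwarg keys second):
flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
assert flat_args == (1, 2, 3, 4) and kwarg_keys == ("a", "b")

args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
assert args == (1, 2) and kwargs == {"a": 3, "b": 4}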
@ -94,7 +85,7 @@ T = TypeVar("T", torch.Tensor, PackedSequence)
@overload
def _recursive_to(
inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> List[S]:
) -> list[S]:
...
@ -264,10 +255,10 @@ def _apply_to_tensors(fn, container):
def _to_kwargs(
inputs: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]],
kwargs: Optional[dict[str, Any]],
target_device: torch.device,
use_side_stream_for_tensor_copies: bool,
) -> tuple[tuple[Any, ...], tuple[Dict[str, Any], ...]]:
) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]:
moved_inputs = (
_recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
if inputs
@ -287,7 +278,7 @@ def _to_kwargs(
def _verify_param_shape_across_processes(
process_group: dist.ProcessGroup,
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
logger: Optional["dist.Logger"] = None,
):
return dist._verify_params_across_processes(process_group, tensors, logger)
@ -309,7 +300,7 @@ def _sync_module_states(
parameter shapes are consistent before running the synchronization. This can
be checked with ``_verify_param_shape_across_processes``.
"""
module_states: List[torch.Tensor] = []
module_states: list[torch.Tensor] = []
for name, param in module.named_parameters():
if name not in params_and_buffers_to_ignore:
module_states.append(param.detach())
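Taken together with the helpers shown above and below, the intended call sequence is roughly the following sketch, assuming `module` is an nn.Module and `process_group` an initialized process group; the bucket size is illustrative, not a documented default:
module_states: list[torch.Tensor] = [p.detach() for p in module.parameters()]
_verify_param_shape_across_processes(process_group, module_states)
_sync_params_and_buffers(process_group, module_states,
                         broadcast_bucket_size=250 * 1024 * 1024, src=0)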
@ -324,7 +315,7 @@ def _sync_module_states(
def _sync_params_and_buffers(
process_group: dist.ProcessGroup,
module_states: List[torch.Tensor],
module_states: list[torch.Tensor],
broadcast_bucket_size: int,
src: int,
) -> None:
@ -336,7 +327,7 @@ def _sync_params_and_buffers(
def _replace_by_prefix(
state_dict: Dict[str, Any],
state_dict: dict[str, Any],
old_prefix: str,
new_prefix: str,
) -> None:
@ -363,15 +354,15 @@ def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
return tensor.untyped_storage().data_ptr() > 0
def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]:
def _get_root_modules(modules: list[nn.Module]) -> list[nn.Module]:
"""
Returns the modules in ``modules`` that are root modules (i.e.
parent-less) with respect to the set ``modules``. In other words, these
are the modules in ``modules`` that are not a child of any other
module in ``modules``.
"""
root_modules: List[nn.Module] = []
module_to_modules: Dict[nn.Module, Set[nn.Module]] = {
root_modules: list[nn.Module] = []
module_to_modules: dict[nn.Module, set[nn.Module]] = {
module: set(module.modules()) for module in modules
}
for candidate_module in modules: