PEP585 update - torch/distributed (#145164)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93

commit 00ffeca1b1 (parent c6986ca2e1)
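
Note: the mechanical change applied throughout the diff below is the PEP 585 style of spelling generics with the builtin containers and `collections.abc` types instead of their deprecated `typing` aliases (`List` to `list`, `Dict` to `dict`, `Type` to `type`, `Set` to `set`, and `Sequence`/`Mapping`/`Iterable`/`Generator` imported from `collections.abc`). A minimal before/after sketch of the pattern, not taken from the PR; `process_batch` and its parameters are hypothetical names:

```python
# Before: deprecated typing aliases (pre-PEP 585 style)
from typing import Dict, List, Optional, Sequence

def process_batch(items: List[int], labels: Dict[str, int],
                  order: Optional[Sequence[str]] = None) -> List[str]:
    return [str(i) for i in items]

# After: builtin generics plus collections.abc, the style this commit adopts
from collections.abc import Sequence
from typing import Optional

def process_batch(items: list[int], labels: dict[str, int],
                  order: Optional[Sequence[str]] = None) -> list[str]:
    return [str(i) for i in items]
```

Subscripting the builtin containers at runtime requires Python 3.9 or newer; on older interpreters the same annotations are accepted under `from __future__ import annotations`.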
@@ -79,7 +79,7 @@ if is_available():
 finally:
 sys.stdin = _stdin

-_breakpoint_cache: typing.Dict[int, typing.Any] = {}
+_breakpoint_cache: dict[int, typing.Any] = {}

 def breakpoint(rank: int = 0, skip: int = 0):
 """
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import List, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 import torch

@@ -12,7 +12,7 @@ class _Checkpointable(Protocol): # noqa: PYI046
 This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
 """

-def __create_write_items__(self, fqn: str, object: object) -> List[object]:
+def __create_write_items__(self, fqn: str, object: object) -> list[object]:
 """
 Return a list of WriteItems based on object's contents.
 """

@@ -20,7 +20,7 @@ class _Checkpointable(Protocol): # noqa: PYI046
 "_Checkpointable._create_write_items is not implemented"
 )

-def __create_chunk_list__(self) -> List[object]:
+def __create_chunk_list__(self) -> list[object]:
 """
 Return a list of `ChunkStorageMetadata` based on object's contents.
 """
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
+from collections.abc import Generator
 from contextlib import contextmanager, nullcontext
-from typing import Any, ContextManager, Dict, Generator, Optional
+from typing import Any, ContextManager, Optional

 import torch
 import torch.nn as nn

@@ -85,7 +86,7 @@ def checkpoint(module: nn.Module, **kwargs) -> nn.Module:
 )

 def forward_pre_hook(
-module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
+module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> None:
 if checkpoint.state(module).enable_hook:
@@ -1,20 +1,9 @@
 # mypy: allow-untyped-defs
 import uuid
 from collections import OrderedDict
+from collections.abc import Sequence
 from functools import wraps
-from typing import (
-Callable,
-Dict,
-Generic,
-List,
-Optional,
-overload,
-Protocol,
-Sequence,
-Type,
-TypeVar,
-Union,
-)
+from typing import Callable, Generic, Optional, overload, Protocol, TypeVar, Union
 from typing_extensions import Concatenate, ParamSpec

 import torch

@@ -66,7 +55,7 @@ def contract() -> (

 @overload
 def contract(
-state_cls: Type[_TState],
+state_cls: type[_TState],
 ) -> Callable[
 [Callable[Concatenate[nn.Module, _P], Optional[nn.Module]]],
 _ContractFn[Concatenate[nn.Module, _P], _T, _TState],

@@ -75,7 +64,7 @@ def contract(


 def contract(
-state_cls: Type = _State,
+state_cls: type = _State,
 ) -> Callable[
 [
 Callable[

@@ -153,21 +142,21 @@ def contract(
 # `func` is allowed to return different module instances than the
 # input modules as long as FQNs are preserved following the input
 # module order
-all_orig_named_params: List[Dict[str, nn.Parameter]] = []
-all_orig_named_buffers: List[Dict[str, torch.Tensor]] = []
-all_orig_named_modules: List[Dict[str, nn.Module]] = []
+all_orig_named_params: list[dict[str, nn.Parameter]] = []
+all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
+all_orig_named_modules: list[dict[str, nn.Module]] = []

 for module in modules:
-default_all_state: Dict[Callable, _State] = OrderedDict()
-default_registry: Dict[str, RegistryItem] = OrderedDict()
-all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
+default_all_state: dict[Callable, _State] = OrderedDict()
+default_registry: dict[str, RegistryItem] = OrderedDict()
+all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
 STATE_KEY, default_all_state
 )
 if not isinstance(all_state, dict):
 raise AssertionError(
 f"Distributed composable API states corrupted: {all_state}"
 )
-registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
+registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
 REGISTRY_KEY, default_registry
 )
 if not isinstance(registry, dict):

@@ -195,9 +184,9 @@ def contract(
 else:
 updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type]

-all_new_named_params: List[Dict[str, nn.Parameter]] = []
-all_new_named_buffers: List[Dict[str, torch.Tensor]] = []
-all_new_named_modules: List[Dict[str, nn.Module]] = []
+all_new_named_params: list[dict[str, nn.Parameter]] = []
+all_new_named_buffers: list[dict[str, torch.Tensor]] = []
+all_new_named_modules: list[dict[str, nn.Module]] = []
 for module in updated_modules:
 all_new_named_params.append(OrderedDict(module.named_parameters()))
 all_new_named_buffers.append(OrderedDict(module.named_buffers()))

@@ -212,7 +201,7 @@ def contract(
 f"Outputs: {num_new_modules} modules"
 )

-def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
+def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str):
 if orig_fqns == new_fqns:
 return

@@ -280,7 +269,7 @@ def contract(
 return inner


-def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
+def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]:
 r"""
 Get an ``OrderedDict`` of composable APIs that have been applied to the
 ``module``, indexed by the API name. If no API has been applied, then this
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import weakref
-from typing import Any, Dict, Iterable, List, NoReturn, Optional, Set
+from collections.abc import Iterable
+from typing import Any, NoReturn, Optional

 import torch
 import torch.nn as nn

@@ -24,17 +25,17 @@ class _ReplicateState(_State):
 # TODO(@fegin): this variable is originally create for testing, we
 # should remove this if possible.
 self._orig_module = self.module
-self._param_names: List[str] = []
+self._param_names: list[str] = []
 self._no_sync: bool = False
 self._init_args: Optional[tuple[Any, ...]] = None
-self._init_kwargs: Dict[str, Any] = {}
-self._comm_hook_args: List[Any] = []
+self._init_kwargs: dict[str, Any] = {}
+self._comm_hook_args: list[Any] = []

 def _collect_params(
 self,
 module: nn.Module,
-ignored_modules: Set[nn.Module],
-ignored_params: Set[nn.Parameter],
+ignored_modules: set[nn.Module],
+ignored_params: set[nn.Parameter],
 prefix: str = _ROOT_MODULE_PREFIX,
 ) -> None:
 # skip if managed by fully_sharded API

@@ -76,7 +77,7 @@ class _ReplicateState(_State):
 def init(
 self,
 module: nn.Module,
-ignored_modules: Set[nn.Module],
+ignored_modules: set[nn.Module],
 **kwargs,
 ) -> None:
 if self.has_initialized:

@@ -125,7 +126,7 @@ class _ReplicateState(_State):
 self._init_kwargs = kwargs

 def forward_pre_hook(
-self, module: nn.Module, args: tuple[Any, ...], kwargs: Dict[str, Any]
+self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> Any:
 if self._init_args or self._init_kwargs:
 self.lazy_init()
@@ -2,7 +2,7 @@
 import contextlib
 import sys
 import warnings
-from typing import Any, cast, List, Optional, Type, TYPE_CHECKING, Union
+from typing import Any, cast, List, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist

@@ -94,8 +94,8 @@ Functional collectives can accept any of these types to describe the ranks parti
 The different types will be desugared to a canonical format
 """
 RANK_TYPES = Union[
-List[int],
-List[List[int]],
+list[int],
+list[list[int]],
 dist.ProcessGroup,
 DeviceMesh,
 tuple["dist.tensor.DeviceMesh", int],

@@ -331,8 +331,8 @@ def reduce_scatter_tensor_autograd(


 def all_reduce_coalesced(
-self: List[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
-) -> List[torch.Tensor]:
+self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
+) -> list[torch.Tensor]:
 """
 Reduces a list of tensors across all machines in such a way that all get
 the final result.

@@ -359,8 +359,8 @@ def all_reduce_coalesced(


 def all_gather_into_tensor_coalesced(
-self: List[torch.Tensor], group: RANK_TYPES, tag: str = ""
-) -> List[torch.Tensor]:
+self: list[torch.Tensor], group: RANK_TYPES, tag: str = ""
+) -> list[torch.Tensor]:
 """
 Gather a list of tensors across from all machines.

@@ -388,12 +388,12 @@ def all_gather_into_tensor_coalesced(


 def reduce_scatter_tensor_coalesced(
-inputs: List[torch.Tensor],
+inputs: list[torch.Tensor],
 reduceOp: str,
-scatter_dim: List[int],
+scatter_dim: list[int],
 group: RANK_TYPES,
 tag: str = "",
-) -> List[torch.Tensor]:
+) -> list[torch.Tensor]:
 """
 Reduces a list of tensors across all machines in such a way that all get
 the final result, then scatter the results to corresponding ranks.

@@ -450,8 +450,8 @@ def _is_view_op(tgt):

 def all_to_all_single(
 self: torch.Tensor,
-output_split_sizes: Optional[List[int]],
-input_split_sizes: Optional[List[int]],
+output_split_sizes: Optional[list[int]],
+input_split_sizes: Optional[list[int]],
 group: RANK_TYPES,
 tag: str = "",
 ) -> torch.Tensor:

@@ -498,8 +498,8 @@ def all_to_all_single(

 def all_to_all_single_autograd(
 self: torch.Tensor,
-output_split_sizes: Optional[List[int]],
-input_split_sizes: Optional[List[int]],
+output_split_sizes: Optional[list[int]],
+input_split_sizes: Optional[list[int]],
 group: RANK_TYPES,
 tag: str = "",
 ) -> torch.Tensor:

@@ -535,7 +535,7 @@ def all_to_all_single_autograd(

 def permute_tensor(
 self: torch.Tensor,
-src_dst: List[int],
+src_dst: list[int],
 group: RANK_TYPES,
 tag: str = "",
 ) -> torch.Tensor:

@@ -608,7 +608,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 return AsyncCollectiveTensor(elem)

 def __coerce_same_metadata_as_tangent__(
-self, expected_metadata: Any, expected_type: Optional[Type] = None
+self, expected_metadata: Any, expected_type: Optional[type] = None
 ):
 if expected_type is not torch.Tensor:
 return None

@@ -677,7 +677,7 @@ Utils and infrastructure for tracing support
 """


-def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int]:
+def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]:
 """
 _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

@@ -690,10 +690,10 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int
 if TYPE_CHECKING:

 def cast_listlistint(x):
-return cast(List[List[int]], x)
+return cast(list[list[int]], x)

 def cast_listint(x):
-return cast(List[int], x)
+return cast(list[int], x)

 else:
 # fake cast op for use at runtime since dynamo doesn't support real cast

@@ -705,7 +705,7 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, List[int], int
 def cast_listint(x):
 return x

-rankset: List[int]
+rankset: list[int]
 if isinstance(group, list):
 if isinstance(group[0], list):
 nested_list = cast_listlistint(group)

@@ -1143,7 +1143,7 @@ def all_to_all_inplace(


 def all_gather_inplace(
-tensor_list: List[torch.Tensor],
+tensor_list: list[torch.Tensor],
 tensor: torch.Tensor,
 group=None,
 async_op=False,
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import List, Optional
+from typing import Optional

 import torch
 import torch.distributed.distributed_c10d as c10d

@@ -60,7 +60,7 @@ def _reduce_scatter_tensor(
 input: torch.Tensor,
 reduce_op: str,
 tag: str,
-ranks: List[int],
+ranks: list[int],
 group_size: int,
 ):
 group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)

@@ -73,10 +73,10 @@ def _reduce_scatter_tensor(


 def _reduce_scatter_tensor_coalesced(
-inputs: List[torch.Tensor],
+inputs: list[torch.Tensor],
 reduce_op: str,
 tag: str,
-ranks: List[int],
+ranks: list[int],
 group_size: int,
 ):
 group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag)

@@ -90,10 +90,10 @@ def _reduce_scatter_tensor_coalesced(

 def _all_to_all_single(
 input: torch.Tensor,
-output_split_sizes: Optional[List[int]],
-input_split_sizes: Optional[List[int]],
+output_split_sizes: Optional[list[int]],
+input_split_sizes: Optional[list[int]],
 tag: str,
-ranks: List[int],
+ranks: list[int],
 group_size: int,
 ):
 if output_split_sizes is None or input_split_sizes is None:
@ -1,4 +1,4 @@
|
|||
from typing import Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from torch.distributed.remote_device import _remote_device
|
||||
|
||||
|
|
@ -25,14 +25,14 @@ class ShardMetadata:
|
|||
|
||||
__slots__ = ["shard_offsets", "shard_sizes", "placement"]
|
||||
|
||||
shard_offsets: List[int]
|
||||
shard_sizes: List[int]
|
||||
shard_offsets: list[int]
|
||||
shard_sizes: list[int]
|
||||
placement: Optional[_remote_device]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shard_offsets: List[int],
|
||||
shard_sizes: List[int],
|
||||
shard_offsets: list[int],
|
||||
shard_sizes: list[int],
|
||||
placement: Optional[Union[str, _remote_device]] = None,
|
||||
):
|
||||
self.shard_offsets = shard_offsets
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Iterator, Tuple, Union
|
||||
from collections.abc import Iterator
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Any, Dict, List, Mapping, Union
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
import torch.optim as optim
|
||||
from torch import Tensor
|
||||
|
|
@ -28,7 +29,7 @@ class ShardedOptimizer(optim.Optimizer):
|
|||
**optimizer_kwargs: the key-word arguments to initialize the optimizer.
|
||||
|
||||
"""
|
||||
tensors: List[Tensor] = []
|
||||
tensors: list[Tensor] = []
|
||||
for value in named_params.values():
|
||||
if isinstance(value, ShardedTensor):
|
||||
tensors.extend(
|
||||
|
|
@ -72,7 +73,7 @@ class ShardedOptimizer(optim.Optimizer):
|
|||
"""
|
||||
self._optim.step(closure)
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returned state and param_groups will contain parameter keys
|
||||
instead of parameter indices like torch.optim.Optimizer.
|
||||
|
|
|
|||
|
|
@ -356,7 +356,7 @@ def randn(
|
|||
|
||||
|
||||
def init_from_local_shards(
|
||||
local_shards: List[Shard], *global_size, process_group=None, init_rrefs=False
|
||||
local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False
|
||||
) -> ShardedTensor:
|
||||
"""
|
||||
Creates an :class:`ShardedTensor` from local shards and the global metadata.
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import warnings
|
|||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
from typing import Callable, cast, Dict, List, Optional, Sequence, TYPE_CHECKING
|
||||
from typing import Callable, cast, Optional, TYPE_CHECKING
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
|
@ -41,23 +41,25 @@ from .utils import (
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
|
||||
|
||||
# Tracking for sharded tensor objects.
|
||||
_sharded_tensor_lock = threading.Lock()
|
||||
_sharded_tensor_current_id = 0
|
||||
_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {}
|
||||
_sharded_tensor_map: dict[int, weakref.ReferenceType[ShardedTensor]] = {}
|
||||
|
||||
# Default sharded ops
|
||||
_SHARDED_OPS: Dict[Callable, Callable] = {}
|
||||
_SHARDED_OPS: dict[Callable, Callable] = {}
|
||||
|
||||
# Customized user ops
|
||||
_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
|
||||
_CUSTOM_SHARDED_OPS: dict[Callable, Callable] = {}
|
||||
|
||||
|
||||
def _register_remote_shards(
|
||||
sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int
|
||||
sharded_tensor_id: int, rrefs: list[rpc.RRef[Shard]], rpc_rank: int
|
||||
):
|
||||
with _sharded_tensor_lock:
|
||||
if sharded_tensor_id not in _sharded_tensor_map:
|
||||
|
|
@ -75,7 +77,7 @@ def _register_remote_shards(
|
|||
class ShardedTensorBase(torch.Tensor):
|
||||
_sharding_spec: shard_spec.ShardingSpec
|
||||
_metadata: ShardedTensorMetadata
|
||||
_local_shards: List[Shard]
|
||||
_local_shards: list[Shard]
|
||||
|
||||
def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
|
||||
# Use __new__ to construct a wrapper tensor, for recording tensor
|
||||
|
|
@ -125,7 +127,7 @@ class ShardedTensorBase(torch.Tensor):
|
|||
"""
|
||||
return self._metadata
|
||||
|
||||
def local_shards(self) -> List[Shard]:
|
||||
def local_shards(self) -> list[Shard]:
|
||||
"""
|
||||
Returns a list of :class:`Shard' corresponding to the
|
||||
local shards for this rank. Returns an empty list if the current rank
|
||||
|
|
@ -136,7 +138,7 @@ class ShardedTensorBase(torch.Tensor):
|
|||
@classmethod
|
||||
def _init_from_local_shards_and_global_metadata(
|
||||
cls,
|
||||
local_shards: List[Shard],
|
||||
local_shards: list[Shard],
|
||||
sharded_tensor_metadata: ShardedTensorMetadata,
|
||||
sharding_spec=None,
|
||||
) -> ShardedTensorBase:
|
||||
|
|
@ -290,7 +292,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
self._sharded_tensor_id = None
|
||||
|
||||
self._process_group = self._normalize_pg(process_group)
|
||||
self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
|
||||
self._remote_shards: dict[int, list[rpc.RRef[Shard]]] = {}
|
||||
|
||||
def _post_init(self):
|
||||
# Initialize RPC if available.
|
||||
|
|
@ -351,7 +353,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
continue
|
||||
|
||||
if len(self.local_shards()) != 0:
|
||||
rrefs: List[rpc.RRef[Shard]] = [
|
||||
rrefs: list[rpc.RRef[Shard]] = [
|
||||
rpc.RRef(shard) for shard in self.local_shards()
|
||||
]
|
||||
fut = rpc.rpc_async(
|
||||
|
|
@ -430,7 +432,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
world_size = dist.get_world_size(self._process_group)
|
||||
rank_sizes = [0 for _ in range(world_size)]
|
||||
max_rank_size = 0
|
||||
shard_placement: Dict[ShardMetadata, tuple[int, int]] = {}
|
||||
shard_placement: dict[ShardMetadata, tuple[int, int]] = {}
|
||||
# collect sizes
|
||||
for shard_md in self.metadata().shards_metadata:
|
||||
shard_rank = cast(_remote_device, shard_md.placement).rank()
|
||||
|
|
@ -440,7 +442,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
rank_sizes[shard_rank] += shard_size(shard_md)
|
||||
max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
|
||||
|
||||
gather_list: Optional[List[torch.Tensor]]
|
||||
gather_list: Optional[list[torch.Tensor]]
|
||||
if rank == dst:
|
||||
assert out is not None
|
||||
if enforce_dtype:
|
||||
|
|
@ -535,7 +537,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
return self
|
||||
|
||||
# if not, returns a copy of this object on CPU
|
||||
list_shards: List[Shard] = []
|
||||
list_shards: list[Shard] = []
|
||||
# move all local shards to cpu, and change metadata
|
||||
for shard in self._local_shards:
|
||||
cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg]
|
||||
|
|
@ -594,7 +596,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
|
||||
current_device = torch.device(torch.cuda.current_device())
|
||||
# returns a copy of ShardedTensor on CUDA current device
|
||||
list_shards: List[Shard] = []
|
||||
list_shards: list[Shard] = []
|
||||
# move all local shards to current device, and change metadata
|
||||
# if local shards already on the current device, there's no
|
||||
# real data movement, only the metadata are copied.
|
||||
|
|
@ -683,7 +685,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
return self
|
||||
|
||||
# returns a copy of ShardedTensor on CUDA current device
|
||||
list_shards: List[Shard] = []
|
||||
list_shards: list[Shard] = []
|
||||
|
||||
for shard in self._local_shards:
|
||||
new_tensor = shard.tensor.to( # type: ignore[call-overload]
|
||||
|
|
@ -726,7 +728,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
@classmethod
|
||||
def _init_from_local_shards(
|
||||
cls,
|
||||
local_shards: List[Shard],
|
||||
local_shards: list[Shard],
|
||||
*global_size,
|
||||
process_group=None,
|
||||
init_rrefs=False,
|
||||
|
|
@ -746,7 +748,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
|
||||
# STEP 2. Validate metadata across ranks, and build a global sharded tensor
|
||||
# metadata by gathering local ShardedTensorMetadata
|
||||
gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
|
||||
gathered_metadatas: list[Optional[ShardedTensorMetadata]] = []
|
||||
if world_size > 1:
|
||||
gathered_metadatas = [None for _ in range(world_size)]
|
||||
|
||||
|
|
@ -866,7 +868,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
process_group = cls._normalize_pg(process_group)
|
||||
current_rank = dist.get_rank() # intentional to get global rank
|
||||
|
||||
local_shards: List[Shard] = []
|
||||
local_shards: list[Shard] = []
|
||||
for shard_metadata in sharded_tensor_metadata.shards_metadata:
|
||||
rank, _device = _parse_and_validate_remote_device(
|
||||
process_group, shard_metadata.placement
|
||||
|
|
@ -887,7 +889,7 @@ class ShardedTensor(ShardedTensorBase):
|
|||
@classmethod
|
||||
def _init_from_local_shards_and_global_metadata( # type: ignore[override]
|
||||
cls,
|
||||
local_shards: List[Shard],
|
||||
local_shards: list[Shard],
|
||||
sharded_tensor_metadata: ShardedTensorMetadata,
|
||||
process_group=None,
|
||||
init_rrefs=False,
|
||||
|
|
@ -1190,11 +1192,11 @@ class ShardedTensor(ShardedTensorBase):
|
|||
return self._metadata.tensor_properties.pin_memory
|
||||
|
||||
def _register_remote_shards(
|
||||
self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int
|
||||
self, remote_shards: list[rpc.RRef[Shard]], rpc_rank: int
|
||||
):
|
||||
self._remote_shards[rpc_rank] = remote_shards
|
||||
|
||||
def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]:
|
||||
def remote_shards(self) -> dict[int, list[rpc.RRef[Shard]]]:
|
||||
"""
|
||||
Returns a Dict[int, RRef] with keys being the RPC rank and values
|
||||
being RRefs to shards on that rank. Need to initialize the
|
||||
|
|
|
|||
|
|
@ -7,12 +7,11 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
def _get_or_create_logger() -> logging.Logger:
|
||||
|
|
|
|||
|
|
@ -7,11 +7,10 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
_log_handlers: Dict[str, logging.Handler] = {
|
||||
_log_handlers: dict[str, logging.Handler] = {
|
||||
"default": logging.NullHandler(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
|
|
@ -87,7 +86,7 @@ class ShardedTensorMetadata:
|
|||
"""
|
||||
|
||||
# Metadata about each shard of the Tensor
|
||||
shards_metadata: List[ShardMetadata] = field(default_factory=list)
|
||||
shards_metadata: list[ShardMetadata] = field(default_factory=list)
|
||||
|
||||
# Size of each dim of the overall Tensor.
|
||||
size: torch.Size = field(default=torch.Size([]))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -44,7 +43,7 @@ def build_reshard_metadata(
|
|||
st_size: torch.Size,
|
||||
sharding_spec: shard_spec.ShardingSpec,
|
||||
world_size: int,
|
||||
) -> tuple[List[ShardMetadata], List[int]]:
|
||||
) -> tuple[list[ShardMetadata], list[int]]:
|
||||
"""
|
||||
Based the given sharding spec, we calculate the offset and local shard size.
|
||||
We then build a ShardMetadata on top of the calculation result.
|
||||
|
|
@ -86,7 +85,7 @@ def reshuffle_local_shard(
|
|||
sharding_spec: shard_spec.ShardingSpec,
|
||||
resharding_spec: shard_spec.ShardingSpec,
|
||||
pg: ProcessGroup,
|
||||
) -> tuple[List[Shard], List[ShardMetadata]]:
|
||||
) -> tuple[list[Shard], list[ShardMetadata]]:
|
||||
"""
|
||||
Reshuffle the local shard directly when the reshard dim is same as the original
|
||||
sharding dim. Logically we do this in two step:
|
||||
|
|
@ -155,7 +154,7 @@ def reshard_local_shard(
|
|||
sharding_spec: shard_spec.ShardingSpec,
|
||||
resharding_spec: shard_spec.ShardingSpec,
|
||||
pg: ProcessGroup,
|
||||
) -> tuple[List[Shard], List[ShardMetadata]]:
|
||||
) -> tuple[list[Shard], list[ShardMetadata]]:
|
||||
"""
|
||||
Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is
|
||||
different from the original sharding dim, we need to do two steps logically:
|
||||
|
|
@ -198,7 +197,7 @@ def reshard_local_shard(
|
|||
|
||||
if rearrange_input:
|
||||
# Need to re-arrange reshard_dim of local_tensor before all2all.
|
||||
indices: List[int] = []
|
||||
indices: list[int] = []
|
||||
for metadata in shards_metadata:
|
||||
offset_start_idx = metadata.shard_offsets[reshard_dim]
|
||||
split_size = metadata.shard_sizes[reshard_dim]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
|
|
@ -43,7 +42,7 @@ class Shard:
|
|||
|
||||
@classmethod
|
||||
def from_tensor_and_offsets(
|
||||
cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int
|
||||
cls, tensor: torch.Tensor, shard_offsets: list[int], rank: int
|
||||
) -> "Shard":
|
||||
"""
|
||||
Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import collections.abc
|
||||
import copy
|
||||
from typing import List, Optional, Sequence, TYPE_CHECKING
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.distributed import distributed_c10d as c10d, rpc
|
||||
|
|
@ -109,13 +110,13 @@ def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True):
|
|||
|
||||
|
||||
def build_metadata_from_local_shards(
|
||||
local_shards: List[Shard],
|
||||
local_shards: list[Shard],
|
||||
global_size: torch.Size,
|
||||
current_rank: int,
|
||||
pg: c10d.ProcessGroup,
|
||||
) -> ShardedTensorMetadata:
|
||||
assert len(local_shards) > 0, "must have local shards!"
|
||||
local_shard_metadatas: List[ShardMetadata] = []
|
||||
local_shard_metadatas: list[ShardMetadata] = []
|
||||
|
||||
first_shard_dtype = local_shards[0].tensor.dtype
|
||||
first_shard_layout = local_shards[0].tensor.layout
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.distributed._shard.sharder import Sharder
|
||||
|
|
@ -62,9 +62,9 @@ class ShardingPlan:
|
|||
>>> )
|
||||
"""
|
||||
|
||||
plan: Dict[str, Union[ShardingSpec, Sharder]]
|
||||
output_plan: Optional[Dict[str, ShardingSpec]] = None
|
||||
return_local_tensor: Optional[List[str]] = None
|
||||
plan: dict[str, Union[ShardingSpec, Sharder]]
|
||||
output_plan: Optional[dict[str, ShardingSpec]] = None
|
||||
return_local_tensor: Optional[list[str]] = None
|
||||
|
||||
|
||||
class ShardingPlanner(abc.ABC):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from torch.distributed._shard.metadata import ShardMetadata
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad
|
|||
|
||||
|
||||
def _find_nd_overlapping_shards(
|
||||
shards: List[ShardMetadata], sharded_dims: List[int]
|
||||
shards: list[ShardMetadata], sharded_dims: list[int]
|
||||
) -> Optional[tuple[int, int]]:
|
||||
# Each rank has len(sharded_dims) tuples. Each tuple represent the
|
||||
# [begin, end] (inclusive) pair of that dimension.
|
||||
|
|
@ -55,7 +55,7 @@ def _find_nd_overlapping_shards(
|
|||
|
||||
|
||||
def _find_1d_overlapping_shards(
|
||||
shards: List[ShardMetadata], dim: int
|
||||
shards: list[ShardMetadata], dim: int
|
||||
) -> Optional[tuple[int, int]]:
|
||||
# (begin, end, index_in_shards). Begin and end are inclusive.
|
||||
intervals = [
|
||||
|
|
@ -69,7 +69,7 @@ def _find_1d_overlapping_shards(
|
|||
return None
|
||||
|
||||
|
||||
def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
|
||||
def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]):
|
||||
"""
|
||||
Ensures none of the shards overlap with each other.
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ def validate_non_overlapping_shards_metadata(shards: List[ShardMetadata]):
|
|||
if not shards or len(shards) == 1:
|
||||
return
|
||||
|
||||
sharded_dims: List[int] = []
|
||||
sharded_dims: list[int] = []
|
||||
for dim in range(len(shards[0].shard_offsets)):
|
||||
for i in range(1, len(shards)):
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import functools
|
|||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, TYPE_CHECKING
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
|
||||
|
|
@ -95,7 +95,7 @@ class ShardingSpec(ABC):
|
|||
|
||||
|
||||
# Ops customized for a particular ShardingSpec.
|
||||
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}
|
||||
_CUSTOM_SHARDING_SPEC_OPS: dict[str, dict[Callable, Callable]] = {}
|
||||
|
||||
|
||||
def _has_custom_op(sharding_spec, op):
|
||||
|
|
@ -148,7 +148,7 @@ class EnumerableShardingSpec(ShardingSpec):
|
|||
each shard. Note that none of the shards should overlap.
|
||||
"""
|
||||
|
||||
shards: List[ShardMetadata]
|
||||
shards: list[ShardMetadata]
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.shards) == 0:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import cast, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -53,7 +53,7 @@ class ChunkShardingSpec(ShardingSpec):
|
|||
ShardingDim = Union[int, str]
|
||||
|
||||
dim: ShardingDim
|
||||
placements: List[Union[torch.distributed._remote_device, str]]
|
||||
placements: list[Union[torch.distributed._remote_device, str]]
|
||||
|
||||
def __post_init__(self):
|
||||
self._verify_dim(self.dim)
|
||||
|
|
@ -134,7 +134,7 @@ class ChunkShardingSpec(ShardingSpec):
|
|||
local_metadata = None
|
||||
|
||||
tensors_to_scatter = cast(
|
||||
List[Optional[torch.Tensor]],
|
||||
list[Optional[torch.Tensor]],
|
||||
[None] * dist.get_world_size(process_group),
|
||||
)
|
||||
|
||||
|
|
@ -195,7 +195,7 @@ class ChunkShardingSpec(ShardingSpec):
|
|||
process_group, src_for_scatter
|
||||
)
|
||||
|
||||
tensors_to_scatter_: Optional[List[torch.Tensor]] = None
|
||||
tensors_to_scatter_: Optional[list[torch.Tensor]] = None
|
||||
if current_rank == src_rank:
|
||||
tensors_to_scatter_ = []
|
||||
for t in tensors_to_scatter:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
from typing import cast, List
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -222,7 +222,7 @@ def _validate_embedding_bag_param(args, kwargs):
|
|||
)
|
||||
if include_last_offset and offsets is None:
|
||||
raise ValueError('offsets is required for flag "include_last_offset"!')
|
||||
if include_last_offset and cast(List[int], offsets)[-1] != input.size(0):
|
||||
if include_last_offset and cast(list[int], offsets)[-1] != input.size(0):
|
||||
raise ValueError(
|
||||
'offsets need to have the input size in the end when the flag "include_last_offset" is on!'
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,19 +3,8 @@ import copy
|
|||
import io
|
||||
import math
|
||||
import weakref
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -94,7 +83,7 @@ def _iterate_state_dict(
|
|||
ranks_only: tuple[int, ...] = (),
|
||||
type_check: bool = True,
|
||||
non_blocking: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Iterate through the state dict, applying the given functions to each tensor type.
|
||||
|
||||
Args:
|
||||
|
|
@ -203,14 +192,14 @@ def _iterate_state_dict(
|
|||
|
||||
|
||||
def _gather_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
*,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
cpu_offload: bool = False,
|
||||
ranks_only: tuple[int, ...] = (),
|
||||
type_check: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Given a state_dict, this API gathers all the ShardedTensors or DTensors in
|
||||
the state_dict.
|
||||
|
|
@ -291,11 +280,11 @@ def _gather_state_dict(
|
|||
|
||||
|
||||
def _offload_state_dict_to_cpu(
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
*,
|
||||
ranks_only: tuple[int, ...] = (),
|
||||
type_check: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Given a state_dict, this API offload all the tensors to CPU memory.
|
||||
|
||||
|
|
@ -331,11 +320,11 @@ def _offload_state_dict_to_cpu(
|
|||
|
||||
|
||||
def _copy_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
copy_state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
copy_state_dict: dict[str, Any],
|
||||
non_blocking: bool = False,
|
||||
type_check: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Copies all tensors in a given state dict into a different state_dict with the
|
||||
same structure. Additionally, a copied state dict with the same value references
|
||||
|
|
@ -380,8 +369,8 @@ def _copy_state_dict(
|
|||
|
||||
|
||||
def _create_cpu_state_dict(
|
||||
state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Given a state_dict, create another state_dict with the same structure and elements.
|
||||
However, all tensors in the returned state_dict are new tensors on CPU. These
|
||||
|
|
@ -449,8 +438,8 @@ def _create_cpu_state_dict(
|
|||
|
||||
|
||||
def _check_state_dict_similarity(
|
||||
state_dict: Dict[str, Any],
|
||||
compared_state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
compared_state_dict: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Given two state_dicts, check if the structures are the same. And
|
||||
|
|
@ -496,9 +485,9 @@ class _TensorInfo(NamedTuple):
|
|||
|
||||
|
||||
def _broadcast_tensors(
|
||||
full_state_dict: Dict[str, Any],
|
||||
local_state_dict: Dict[str, Any],
|
||||
keys: List[str],
|
||||
full_state_dict: dict[str, Any],
|
||||
local_state_dict: dict[str, Any],
|
||||
keys: list[str],
|
||||
device: torch.device,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
) -> None:
|
||||
|
|
@ -536,8 +525,8 @@ def _broadcast_tensors(
|
|||
|
||||
|
||||
def _distribute_tensors(
|
||||
local_state_dict: Dict[str, Any],
|
||||
keys: List[str],
|
||||
local_state_dict: dict[str, Any],
|
||||
keys: list[str],
|
||||
device: torch.device,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
) -> None:
|
||||
|
|
@ -578,8 +567,8 @@ def _distribute_tensors(
|
|||
|
||||
|
||||
def _broadcast_state_dict(
|
||||
full_state_dict: Dict[str, Any],
|
||||
local_state_dict: Dict[str, Any],
|
||||
full_state_dict: dict[str, Any],
|
||||
local_state_dict: dict[str, Any],
|
||||
device: torch.device,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
strict: bool = False,
|
||||
|
|
@ -637,8 +626,8 @@ def _broadcast_state_dict(
|
|||
|
||||
|
||||
def _distribute_state_dict(
|
||||
full_state_dict: Dict[str, Any],
|
||||
local_state_dict: Dict[str, Any],
|
||||
full_state_dict: dict[str, Any],
|
||||
local_state_dict: dict[str, Any],
|
||||
device: torch.device,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
) -> None:
|
||||
|
|
@ -672,8 +661,8 @@ def _distribute_state_dict(
|
|||
# DCP.
|
||||
PATH_ITEM = Union[str, int]
|
||||
OBJ_PATH = tuple[PATH_ITEM, ...]
|
||||
FLATTEN_MAPPING = Dict[str, OBJ_PATH]
|
||||
STATE_DICT_TYPE = Dict[str, Any]
|
||||
FLATTEN_MAPPING = dict[str, OBJ_PATH]
|
||||
STATE_DICT_TYPE = dict[str, Any]
|
||||
CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any]
|
||||
|
||||
|
||||
|
|
@ -731,14 +720,14 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
|
|||
"""Set ``value`` in ``root_dict`` along the ``path`` object path."""
|
||||
cur_container = cast(CONTAINER_TYPE, root_dict)
|
||||
|
||||
def extend_list(lst: List[Any], idx: int) -> None:
|
||||
def extend_list(lst: list[Any], idx: int) -> None:
|
||||
while len(lst) <= idx:
|
||||
lst.append(None)
|
||||
|
||||
for i in range(1, len(path)):
|
||||
prev_key = path[i - 1]
|
||||
key = path[i]
|
||||
def_val: Union[CONTAINER_TYPE, List[Any]] = {} if type(key) == str else []
|
||||
def_val: Union[CONTAINER_TYPE, list[Any]] = {} if type(key) == str else []
|
||||
|
||||
if isinstance(cur_container, Mapping):
|
||||
cur_container = cast(
|
||||
|
|
@ -752,7 +741,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
|
|||
|
||||
key = path[-1]
|
||||
if type(key) == int:
|
||||
extend_list(cast(List[Any], cur_container), key)
|
||||
extend_list(cast(list[Any], cur_container), key)
|
||||
|
||||
cur_container[key] = value
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@ import math
|
|||
import os
|
||||
import socket
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
|
@ -15,7 +16,7 @@ from torch._C._autograd import DeviceType
|
|||
from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work
|
||||
|
||||
|
||||
_group_name_to_store: Dict[str, c10d.Store] = {}
|
||||
_group_name_to_store: dict[str, c10d.Store] = {}
|
||||
|
||||
|
||||
def enable_symm_mem_for_group(group_name: str) -> None:
|
||||
|
|
@ -95,7 +96,7 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool:
|
|||
return _is_test_mode or group_name in _group_name_to_store
|
||||
|
||||
|
||||
_group_name_to_workspace_tensor: Dict[str, Optional[torch.Tensor]] = {}
|
||||
_group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {}
|
||||
|
||||
|
||||
def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
|
||||
|
|
@ -139,7 +140,7 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
|
|||
return _SymmetricMemory.rendezvous(tensor)
|
||||
|
||||
|
||||
_backend_streams: Dict[int, torch.cuda.Stream] = {}
|
||||
_backend_streams: dict[int, torch.cuda.Stream] = {}
|
||||
|
||||
|
||||
def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream:
|
||||
|
|
@ -149,9 +150,9 @@ def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream:
|
|||
|
||||
|
||||
def _pipelined_multi_all_gather_and_consume(
|
||||
shard: List[torch.Tensor],
|
||||
shard_consumer: Callable[[List[torch.Tensor], int], None],
|
||||
ag_out: List[torch.Tensor],
|
||||
shard: list[torch.Tensor],
|
||||
shard_consumer: Callable[[list[torch.Tensor], int], None],
|
||||
ag_out: list[torch.Tensor],
|
||||
group_name: str,
|
||||
ag_out_needed: bool = True,
|
||||
) -> None:
|
||||
|
|
@ -195,11 +196,11 @@ def _pipelined_multi_all_gather_and_consume(
|
|||
assert x.shape[0] * group_size == y.shape[0]
|
||||
assert x.shape[1:] == y.shape[1:]
|
||||
|
||||
def copy_shard(dst: List[torch.Tensor], src: List[torch.Tensor]) -> None:
|
||||
def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None:
|
||||
for d, s in zip(dst, src):
|
||||
d.copy_(s)
|
||||
|
||||
def get_p2p_bufs(remote_rank: int) -> List[torch.Tensor]:
|
||||
def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]:
|
||||
offset_bytes = 0
|
||||
bufs = []
|
||||
for x in shard:
|
||||
|
|
@ -216,7 +217,7 @@ def _pipelined_multi_all_gather_and_consume(
|
|||
local_p2p_bufs = get_p2p_bufs(rank)
|
||||
|
||||
# shards[i] => shard from rank i
|
||||
shards: List[List[torch.Tensor]] = [[] for _ in range(group_size)]
|
||||
shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)]
|
||||
for x in ag_out:
|
||||
for i, y in enumerate(x.chunk(group_size)):
|
||||
shards[i].append(y)
|
||||
|
|
@ -311,7 +312,7 @@ def _pipelined_all_gather_and_consume(
|
|||
shard_consumer(shard, src_rank)
|
||||
"""
|
||||
|
||||
def adapter(shard: List[torch.Tensor], rank: int) -> None:
|
||||
def adapter(shard: list[torch.Tensor], rank: int) -> None:
|
||||
shard_consumer(shard[0], rank)
|
||||
|
||||
_pipelined_multi_all_gather_and_consume(
|
||||
|
|
@ -505,14 +506,14 @@ def _check_and_verify_fp8_all_gather_scale_mode(
|
|||
def _fused_all_gather_matmul_impl(
|
||||
mm_out_op: torch._ops.OpOverload,
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
A_scale: Optional[torch.Tensor],
|
||||
kwargs_list: List[Dict[str, Any]],
|
||||
out_dtypes: List[Optional[torch.dtype]],
|
||||
kwargs_list: list[dict[str, Any]],
|
||||
out_dtypes: list[Optional[torch.dtype]],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
return_A: bool,
|
||||
) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]:
|
||||
) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]:
|
||||
if A_shard.dim() < 2:
|
||||
raise ValueError("A_shard must be a matrix")
|
||||
for B in Bs:
|
||||
|
|
@ -563,7 +564,7 @@ def _fused_all_gather_matmul_impl(
|
|||
A_scale_shard.shape[1],
|
||||
)
|
||||
|
||||
def row_wise_sharded_consumer(shard: List[torch.Tensor], rank: int) -> None:
|
||||
def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None:
|
||||
for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
|
||||
mm_out_op(
|
||||
shard[0],
|
||||
|
|
@ -630,12 +631,12 @@ def _fused_all_gather_matmul_impl(
|
|||
@torch.library.impl(lib, "fused_all_gather_matmul", "Meta")
|
||||
def _fused_all_gather_matmul_fallback(
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
*,
|
||||
return_A: bool = True,
|
||||
) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]:
|
||||
) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]:
|
||||
group_size = c10d._get_group_size_by_name(group_name)
|
||||
A = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
A_shard.contiguous(), group_size, group_name
|
||||
|
|
@ -652,12 +653,12 @@ def _fused_all_gather_matmul_fallback(
|
|||
@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA")
|
||||
def _fused_all_gather_matmul(
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
*,
|
||||
return_A: bool = True,
|
||||
) -> tuple[Optional[torch.Tensor], List[torch.Tensor]]:
|
||||
) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Perform the following logic with micro-pipelined computation and
|
||||
communication:
|
||||
|
|
@ -703,7 +704,7 @@ def _fused_all_gather_matmul(
|
|||
|
||||
def _should_use_fused_all_gather_matmul_native(
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
) -> bool:
|
||||
|
|
@ -806,9 +807,9 @@ def _should_use_multimem_all_gather_matmul(
|
|||
|
||||
def _multimem_all_gather_matmul(
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
group_name: str,
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
group = c10d._resolve_process_group(group_name)
|
||||
A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:]))
|
||||
symm_mem = get_symm_mem_workspace(
|
||||
|
|
@ -822,16 +823,16 @@ def _multimem_all_gather_matmul(
|
|||
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta")
|
||||
def _fused_all_gather_scaled_matmul_fallback(
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
A_scale: torch.Tensor,
|
||||
B_scales: List[torch.Tensor],
|
||||
B_scales: list[torch.Tensor],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
biases: List[Optional[torch.Tensor]],
|
||||
result_scales: List[Optional[torch.Tensor]],
|
||||
out_dtypes: List[Optional[torch.dtype]],
|
||||
use_fast_accum: List[bool],
|
||||
) -> tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
biases: list[Optional[torch.Tensor]],
|
||||
result_scales: list[Optional[torch.Tensor]],
|
||||
out_dtypes: list[Optional[torch.dtype]],
|
||||
use_fast_accum: list[bool],
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes)
|
||||
|
||||
group_size = c10d._get_group_size_by_name(group_name)
|
||||
|
|
@ -896,16 +897,16 @@ def _fused_all_gather_scaled_matmul_fallback(
|
|||
@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA")
|
||||
def _fused_all_gather_scaled_matmul(
|
||||
A_shard: torch.Tensor,
|
||||
Bs: List[torch.Tensor],
|
||||
Bs: list[torch.Tensor],
|
||||
A_scale: torch.Tensor,
|
||||
B_scales: List[torch.Tensor],
|
||||
B_scales: list[torch.Tensor],
|
||||
gather_dim: int,
|
||||
group_name: str,
|
||||
biases: List[Optional[torch.Tensor]],
|
||||
result_scales: List[Optional[torch.Tensor]],
|
||||
out_dtypes: List[Optional[torch.dtype]],
|
||||
use_fast_accum: List[bool],
|
||||
) -> tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
biases: list[Optional[torch.Tensor]],
|
||||
result_scales: list[Optional[torch.Tensor]],
|
||||
out_dtypes: list[Optional[torch.dtype]],
|
||||
use_fast_accum: list[bool],
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
"""
|
||||
Perform the following logic with micro-pipelined computation and
|
||||
communication:
|
||||
|
|
@ -976,7 +977,7 @@ def _fused_all_gather_scaled_matmul(
|
|||
|
||||
def make_contiguous_for_perm(
|
||||
t: torch.Tensor,
|
||||
perm: List[int],
|
||||
perm: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Restride `t` such that `t.permute(perm)` is contiguous.
|
||||
|
|
@ -1005,7 +1006,7 @@ def _fused_matmul_reduce_scatter_impl(
|
|||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
kwargs: Dict[str, Any],
|
||||
kwargs: dict[str, Any],
|
||||
out_dtype: Optional[torch.dtype],
|
||||
reduce_op: str,
|
||||
scatter_dim: int,
|
||||
|
|
@ -1237,8 +1238,8 @@ def restride_A_for_fused_matmul_reduce_scatter(
|
|||
|
||||
|
||||
def _maybe_convert_scalar_types_to_dtypes(
|
||||
scalar_types: List[Any],
|
||||
) -> List[Optional[torch.dtype]]:
|
||||
scalar_types: list[Any],
|
||||
) -> list[Optional[torch.dtype]]:
|
||||
"""
|
||||
When a list of `torch.dtype`s is passed through the dispatcher as
|
||||
`ScalarType[]`, it is converted to a list of scalar type enum values. This
|
||||
|
|
@ -1270,7 +1271,7 @@ def _maybe_convert_scalar_types_to_dtypes(
|
|||
if any(not isinstance(x, (type(None), int)) for x in scalar_types):
|
||||
return scalar_types
|
||||
|
||||
dtypes: List[Optional[torch.dtype]] = []
|
||||
dtypes: list[Optional[torch.dtype]] = []
|
||||
for scalar_type in scalar_types:
|
||||
if scalar_type is None:
|
||||
dtypes.append(scalar_type)
|
||||
|
|
@ -1507,7 +1508,8 @@ def _low_contention_reduce_scatter(
|
|||
# =============================================================================
|
||||
|
||||
|
||||
from typing import Any, overload, Sequence, TYPE_CHECKING, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, overload, TYPE_CHECKING, Union
|
||||
|
||||
from torch.types import _device, _dtype, _int
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
|
||||
|
||||
import torch
|
||||
|
|
@ -111,9 +111,9 @@ class _FSDPModMemStats:
|
|||
|
||||
def __init__(self, mod_fqn: str) -> None:
|
||||
self.mod_fqn = mod_fqn
|
||||
self.local_peak: Dict[torch.device, int] = {}
|
||||
self.snapshots: Dict[
|
||||
_FSDPModState, List[Dict[torch.device, Dict[str, int]]]
|
||||
self.local_peak: dict[torch.device, int] = {}
|
||||
self.snapshots: dict[
|
||||
_FSDPModState, list[dict[torch.device, dict[str, int]]]
|
||||
] = {}
|
||||
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ class FSDPMemTracker(MemTracker):
|
|||
self._in_fake_mode: bool = False
|
||||
self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary()
|
||||
self._saved_collectives: _SavedCollectives
|
||||
self._ref_class: Type[_RefType] = _FSDPRefType
|
||||
self._ref_class: type[_RefType] = _FSDPRefType
|
||||
|
||||
def _instrument_fsdp_sharded_params_grads(
|
||||
self, fsdp_param_group: FSDPParamGroup
|
||||
|
|
@ -190,8 +190,8 @@ class FSDPMemTracker(MemTracker):
|
|||
def _fsdp_state_pre_forward(
|
||||
self,
|
||||
fsdp_mod: FSDPModule,
|
||||
orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], Dict[str, Any]]],
|
||||
) -> Callable[_P, tuple[tuple[Unpack[_Ts]], Dict[str, Any]]]:
|
||||
orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]],
|
||||
) -> Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]]:
|
||||
# We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params
|
||||
# and `all_gather` buffers. There are three cases:
|
||||
# Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats``
|
||||
|
|
@ -208,7 +208,7 @@ class FSDPMemTracker(MemTracker):
|
|||
@wraps(orig_fsdp_state_pre_fw)
|
||||
def inner(
|
||||
*args: _P.args, **kwargs: _P.kwargs
|
||||
) -> tuple[tuple[Unpack[_Ts]], Dict[str, Any]]:
|
||||
) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]:
|
||||
mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod)
|
||||
assert mod_fqn is not None
|
||||
if fsdp_mod not in self.memory_tracking:
|
||||
|
|
@ -538,7 +538,7 @@ class FSDPMemTracker(MemTracker):
|
|||
def barrier(
|
||||
group: Union[ProcessGroup, None] = dist.GroupMember.WORLD,
|
||||
async_op: bool = False,
|
||||
device_ids: Union[List[int], None] = None,
|
||||
device_ids: Union[list[int], None] = None,
|
||||
) -> Union[Work, None]:
|
||||
if self._in_fake_mode:
|
||||
return None
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,6 @@
import copy
- from typing import cast, Dict, List, OrderedDict, TypedDict
+ from collections import OrderedDict
+ from typing import cast, TypedDict

import numpy as np

@@ -15,10 +16,10 @@ from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStat


class ModOrder(TypedDict):
- fw_pre_order: List[str]
- bw_pre_order: List[str]
- fw_post_order: List[str]
- bw_post_order: List[str]
+ fw_pre_order: list[str]
+ bw_pre_order: list[str]
+ fw_post_order: list[str]
+ bw_post_order: list[str]


class ModRuntime(TypedDict):

@@ -60,18 +61,18 @@ class ModStats(TypedDict):
# Number of piecewise-linear functions used for approximating ac tradeoff curve
n_segments: int
# Slopes of the of piecewise-linear functions
- slopes: List[float]
+ slopes: list[float]
# Intercepts of the of piecewise-linear functions
- intercepts: List[float]
+ intercepts: list[float]
# X breakpoints of the of piecewise-linear functions
- breakpoints: List[float]
+ breakpoints: list[float]
# Original trade-off curves
tradeoff_curve: OrderedDict[float, float]


class ModuleInfo(TypedDict):
mod_order: ModOrder
- mod_stats: List[ModStats]
+ mod_stats: list[ModStats]


def aggregate_stats(

@@ -96,12 +97,12 @@ def aggregate_stats(
"""

# Memory stats
- mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict(
+ mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict(
copy.deepcopy(mem_tracker.memory_tracking)
)

# Runtime stats
- mod_runtime_stats: Dict[str, ModRuntime] = {
+ mod_runtime_stats: dict[str, ModRuntime] = {
fqn: {"fw": v["fw"], "bw": v["bw"]}
for fqn, v in runtime_estimator.mod_runtimes.items()
}

@@ -116,7 +117,7 @@ def aggregate_stats(

# Selective Activation Checkpointing stats
sac_estimator.pwlf_sac_tradeoff_curve()
- mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy(
+ mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy(
sac_estimator.sac_mod_tradeoff_stats
)

@@ -192,10 +193,10 @@ class Node(ModStats):

class Graph:
def __init__(self, n: int) -> None:
- self.nodes: List[Node] = []
- self.name2node: Dict[str, Node] = {}
+ self.nodes: list[Node] = []
+ self.name2node: dict[str, Node] = {}
self.ad_matrix = np.zeros((n, n))
- self.fw_post_order: List[str] = []
+ self.fw_post_order: list[str] = []

def add_node(self, node: Node) -> None:
self.nodes.append(node)
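The first hunk of this file also swaps `typing.OrderedDict` for `collections.OrderedDict`; since Python 3.9 the `collections` class is itself a runtime-subscriptable generic, so annotations such as `OrderedDict[float, float]` keep working without the typing alias. A brief sketch with hypothetical names, not taken from the diff:

# Editor's sketch (hypothetical names): collections.OrderedDict is subscriptable
# on Python 3.9+, so the typing.OrderedDict alias is unnecessary.
from collections import OrderedDict

def normalize_curve(curve: OrderedDict[float, float]) -> OrderedDict[float, float]:
    """Scale the y-values of a trade-off curve to the [0, 1] range."""
    peak = max(curve.values(), default=1.0) or 1.0
    return OrderedDict((x, y / peak) for x, y in curve.items())

curve = OrderedDict([(0.0, 10.0), (0.5, 5.0), (1.0, 0.0)])
assert list(normalize_curve(curve).values()) == [1.0, 0.5, 0.0]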
@@ -5,7 +5,7 @@ import warnings
from copy import deepcopy
from enum import auto, Enum
from functools import partial, wraps
- from typing import Any, Callable, Dict, List, Optional, Set, Type, TYPE_CHECKING, Union
+ from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import Self

import torch

@@ -116,8 +116,8 @@ class _ModMemStats:
self.buffer_mem: int
self.input_mem: int
self.output_mem: int
- self.local_peak: Dict[torch.device, int] = {}
- self.snapshots: Dict[_ModState, List[Dict[torch.device, Dict[str, int]]]] = {}
+ self.local_peak: dict[torch.device, int] = {}
+ self.snapshots: dict[_ModState, list[dict[torch.device, dict[str, int]]]] = {}


class _WeakRefInfo:

@@ -171,7 +171,7 @@ class _WeakRefInfo:
return self.mem_consumed

@staticmethod
- def get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
+ def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Recursively extracts untyped storages from a tensor or its subclasses.

@@ -242,7 +242,7 @@ def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]:
return value if divisor == 1 else round(value / divisor, precision)


- def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) -> None:
+ def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) -> None:
if len(snapshot) == 0:
print("No memory tracked.")
return

@@ -261,7 +261,7 @@ def _print_snapshot(snapshot: Dict[torch.device, Dict[str, int]], units: str) ->


def _print_snapshot_tabular(
- snapshot: Dict[torch.device, Dict[str, int]], units: str
+ snapshot: dict[torch.device, dict[str, int]], units: str
) -> None:
if len(snapshot) == 0:
print("No memory tracked.")

@@ -287,7 +287,7 @@ def _print_snapshot_tabular(


def _print_state_snapshots(
- snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
+ snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str
) -> None:
for state, snapshot_list in snapshots.items():
print(f"{state}")

@@ -298,7 +298,7 @@ def _print_state_snapshots(


def _print_state_snapshots_tabular(
- snapshots: Dict[_State, List[Dict[torch.device, Dict[str, int]]]], units: str
+ snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str
) -> None:
try:
from tabulate import tabulate

@@ -393,9 +393,9 @@ class MemTracker(TorchDispatchMode):

def __init__(self) -> None:
self.memory_tracking = WeakIdKeyDictionary()
- self._curr_mem_snap: Dict[torch.device, Dict[str, int]] = {}
- self._peak_mem: Dict[torch.device, int] = {}
- self._peak_mem_snap: Dict[torch.device, Dict[str, int]] = {}
+ self._curr_mem_snap: dict[torch.device, dict[str, int]] = {}
+ self._peak_mem: dict[torch.device, int] = {}
+ self._peak_mem_snap: dict[torch.device, dict[str, int]] = {}
self._param_to_grad_hook_handles = WeakIdKeyDictionary()
self._optimizer_hook_handles: Optional[
tuple[RemovableHandle, RemovableHandle]

@@ -404,7 +404,7 @@ class MemTracker(TorchDispatchMode):
self._WINFO = WeakIdKeyDictionary()
self._mod_tracker = ModTracker()
# This is a general memory tracker which can be used with any ``_RefType`` subclass
- self._ref_class: Type[_RefType] = _MemRefType
+ self._ref_class: type[_RefType] = _MemRefType
# Flags to track if we are in the AC region or optimizer step region
self._in_opt: bool = False
self._in_ac: bool = False

@@ -461,7 +461,7 @@ class MemTracker(TorchDispatchMode):
t: torch.Tensor,
reftype: _RefType,
update_existing: bool = False,
- ) -> Set[_WeakRefInfo]:
+ ) -> set[_WeakRefInfo]:
sts = _WeakRefInfo.get_untyped_storages(t)
winfos = set()
for st in sts:

@@ -565,7 +565,7 @@ class MemTracker(TorchDispatchMode):

def get_tracker_snapshot(
self, type: str = "current"
- ) -> Dict[torch.device, Dict[str, int]]:
+ ) -> dict[torch.device, dict[str, int]]:
"""
Capture a snapshot of the memory usage breakdown per device, based on the specified type.

@@ -865,7 +865,7 @@ class MemTracker(TorchDispatchMode):
tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False.
"""

- def natural_sort_key(s: str) -> List[Union[int, str]]:
+ def natural_sort_key(s: str) -> list[Union[int, str]]:
return [
int(text) if text.isdigit() else text.lower()
for text in re.split("([0-9]+)", s)
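Besides `Dict` and `List`, the hunks above replace `Set[...]` with `set[...]` and `Type[...]` with `type[...]`. The `type[X]` form annotates a class object (X or a subclass), which is how attributes like `_ref_class` are typed. A sketch of both forms, with hypothetical names that are not taken from the diff:

# Editor's sketch (hypothetical names): `type[X]` for class objects, `set[X]`
# for built-in sets, mirroring the Type[...] / Set[...] replacements above.
from enum import Enum, auto

class _RefKind(Enum):
    PARAM = auto()
    GRAD = auto()

def kind_names(ref_class: type[_RefKind], wanted: set[str]) -> set[str]:
    """Return the member names of ``ref_class`` that appear in ``wanted``."""
    return {member.name for member in ref_class if member.name in wanted}

assert kind_names(_RefKind, {"PARAM", "OPT"}) == {"PARAM"}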
@@ -2,8 +2,9 @@
import operator
import pickle
from collections import defaultdict
+ from collections.abc import Sequence
from itertools import chain
- from typing import Any, Callable, Dict, List, no_type_check, Sequence, TYPE_CHECKING
+ from typing import Any, Callable, no_type_check, TYPE_CHECKING

import torch
import torch.nn as nn

@@ -72,12 +73,12 @@ class MemoryTracker:

def __init__(self) -> None:
torch._C._log_api_usage_once("torch.distributed.memory_tracker")
- self._hooks: List[RemovableHandle] = []
- self._operator_names: Dict[str, int] = defaultdict(int)
- self.memories_allocated: Dict[int, Dict[str, float]] = defaultdict()
- self.memories_active: Dict[int, Dict[str, float]] = defaultdict()
- self.memories_reserved: Dict[int, Dict[str, float]] = defaultdict()
- self._markers: Dict[str, int] = defaultdict(int)
+ self._hooks: list[RemovableHandle] = []
+ self._operator_names: dict[str, int] = defaultdict(int)
+ self.memories_allocated: dict[int, dict[str, float]] = defaultdict()
+ self.memories_active: dict[int, dict[str, float]] = defaultdict()
+ self.memories_reserved: dict[int, dict[str, float]] = defaultdict()
+ self._markers: dict[str, int] = defaultdict(int)
self._cur_module_name: str = ""
self._op_index: int = 0
self._num_cuda_retries: int = 0

@@ -133,7 +134,7 @@ class MemoryTracker:

The number of the top operators can be configured.
"""
- op_diff: Dict[str, float] = defaultdict(float)
+ op_diff: dict[str, float] = defaultdict(float)
op_name, previous_allocated_memory = self.memories_allocated[0]
for i in range(1, self._op_index):
op_name, current_allocated_memory = self.memories_allocated[i]
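The import hunk above also moves `Sequence` from `typing` to `collections.abc`: PEP 585 made the abstract container classes (`Sequence`, `Iterable`, `Iterator`, `Mapping`, `Generator`, ...) subscriptable where they are defined, so the `typing` re-exports are no longer needed. A short sketch with hypothetical names, not taken from the diff:

# Editor's sketch (hypothetical names): the abstract container types now come
# from collections.abc and are subscriptable there on Python 3.9+.
from collections.abc import Iterable, Sequence

def mean_of(values: Iterable[float]) -> float:
    """Average an iterable of floats (0.0 for an empty iterable)."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0

def top_k(samples: Sequence[float], k: int) -> list[float]:
    """Return the k largest samples, largest first."""
    return sorted(samples, reverse=True)[:k]

assert mean_of([1.0, 2.0, 3.0]) == 2.0
assert top_k((5.0, 1.0, 9.0), 2) == [9.0, 5.0]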
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import warnings
import weakref
- from typing import Callable, Optional, Set
+ from typing import Callable, Optional

import torch
from torch.autograd.graph import register_multi_grad_hook

@@ -48,7 +48,7 @@ class ModTracker:

"""

- parents: Set[str]
+ parents: set[str]
"""
A Set containing the fqn for each module currently running their forward
"""
@@ -2,7 +2,7 @@
import math
import os
from collections import defaultdict
- from typing import Any, Callable, Dict, List, Set
+ from typing import Any, Callable
from typing_extensions import Self

import torch

@@ -121,13 +121,13 @@ class RuntimeEstimator(TorchDispatchMode):
runtime_estimator.display_modulewise_stats()
"""

- _float_types: Set[torch.dtype] = {
+ _float_types: set[torch.dtype] = {
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
}
- _no_fallback_kernel: Set[torch._ops._OpNamespace] = set()
+ _no_fallback_kernel: set[torch._ops._OpNamespace] = set()
fake_mode: FakeTensorMode

def __init__(self) -> None:

@@ -135,13 +135,13 @@ class RuntimeEstimator(TorchDispatchMode):
self._estimate: Callable
self._estimate_mode_type: str
self._mod_tracker = ModTracker()
- self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict(
+ self.mod_runtimes: dict[str, dict[str, float]] = defaultdict(
lambda: defaultdict(lambda: 0.0)
)
- self.mod_fw_pre_order: List[str] = []
- self.mod_bw_pre_order: List[str] = []
- self.mod_fw_post_order: List[str] = []
- self.mod_bw_post_order: List[str] = []
+ self.mod_fw_pre_order: list[str] = []
+ self.mod_bw_pre_order: list[str] = []
+ self.mod_fw_post_order: list[str] = []
+ self.mod_bw_post_order: list[str] = []
self.total_runtime: float = 0.0

# Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950
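A compatibility note from the editor, not part of the diff: the built-in generics are only subscriptable on Python 3.9+. In pure annotation positions `from __future__ import annotations` (PEP 563) defers evaluation, so the new syntax can also parse on older interpreters; expressions that are actually evaluated at runtime, such as the `cast(list[SyncPayload[T]], ...)` call updated later in this diff, still require a 3.9+ runtime. Whether any given file here relies on the `__future__` import is not shown in the diff. A minimal sketch, with hypothetical names:

# Editor's sketch (hypothetical names): deferred annotations vs. runtime subscription.
from __future__ import annotations  # PEP 563: annotations are stored as strings

from typing import cast

def pairs(flat: list[int]) -> dict[int, int]:
    # The annotations above are never evaluated at runtime under PEP 563.
    it = iter(flat)
    return dict(zip(it, it))

# cast() evaluates its first argument, so this line subscripts the builtin
# `dict` at runtime and needs Python 3.9+.
result = cast(dict[int, int], pairs([1, 2, 3, 4]))
assert result == {1: 2, 3: 4}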
@ -4,7 +4,7 @@ import sys
|
|||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Set
|
||||
from typing import Any, NamedTuple, Optional
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
|
|
@ -42,7 +42,7 @@ _PYTORCH_MIN_ALLOCATE = (
|
|||
)
|
||||
|
||||
|
||||
def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
|
||||
def _get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
|
||||
"""
|
||||
Retrieves untyped storages from a `torch.Tensor` or one of its traceable wrapper-subclass.
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]:
|
|||
return flattened_tensor_storages
|
||||
|
||||
|
||||
def _display_stats_tabular(headers: List[str], table_data: List[List[Any]]) -> None:
|
||||
def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ImportError as err:
|
||||
|
|
@ -125,7 +125,7 @@ class _SACModMetadata:
|
|||
|
||||
start_idx: int
|
||||
force_store_random: bool
|
||||
sac_metadata: List[_SACMetadata]
|
||||
sac_metadata: list[_SACMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -144,13 +144,13 @@ class SACStats:
|
|||
force_store_random (bool): Whether to force store random operator results.
|
||||
"""
|
||||
|
||||
func_names: List[str]
|
||||
runtimes: List[float]
|
||||
memory: List[int]
|
||||
view_like_ops: List[int]
|
||||
rand_ops: List[int]
|
||||
saved_autograd_ops: List[int]
|
||||
inplace_ops: List[tuple[int, int]]
|
||||
func_names: list[str]
|
||||
runtimes: list[float]
|
||||
memory: list[int]
|
||||
view_like_ops: list[int]
|
||||
rand_ops: list[int]
|
||||
saved_autograd_ops: list[int]
|
||||
inplace_ops: list[tuple[int, int]]
|
||||
force_store_random: bool
|
||||
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ class MSPS(NamedTuple):
|
|||
msps (float): Memory per second calculated as memory/runtime.
|
||||
"""
|
||||
|
||||
func_names: Set[str]
|
||||
func_names: set[str]
|
||||
op_idx: int
|
||||
memory: int
|
||||
runtime: float
|
||||
|
|
@ -189,9 +189,9 @@ class SACTradeOffStats:
|
|||
"""
|
||||
|
||||
n_segments: int
|
||||
slopes: List[float]
|
||||
intercepts: List[float]
|
||||
fit_breaks: List[float]
|
||||
slopes: list[float]
|
||||
intercepts: list[float]
|
||||
fit_breaks: list[float]
|
||||
tradeoff_curve: OrderedDict[float, float]
|
||||
sac_memory: int
|
||||
sac_runtime: float
|
||||
|
|
@ -210,11 +210,11 @@ class SACGreedyOrderMeta:
|
|||
msps_meta (List[MSPS]): List of Memory and Runtime Statistics for operators.
|
||||
"""
|
||||
|
||||
recomputed_ops: Set[int]
|
||||
stored_ops: Set[int]
|
||||
inplace_op_groups: Dict[int, Set[int]]
|
||||
random_ops_group: Dict[int, Set[int]]
|
||||
msps_meta: List[MSPS]
|
||||
recomputed_ops: set[int]
|
||||
stored_ops: set[int]
|
||||
inplace_op_groups: dict[int, set[int]]
|
||||
random_ops_group: dict[int, set[int]]
|
||||
msps_meta: list[MSPS]
|
||||
|
||||
|
||||
class SACEstimator(TorchDispatchMode):
|
||||
|
|
@ -251,17 +251,17 @@ class SACEstimator(TorchDispatchMode):
|
|||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sac_mod_stats: Dict[str, SACStats] = {}
|
||||
self.sac_mod_tradeoff_stats: Dict[str, SACTradeOffStats] = {}
|
||||
self.sac_mod_greedy_order_meta: Dict[str, SACGreedyOrderMeta] = {}
|
||||
self.sac_mod_stats: dict[str, SACStats] = {}
|
||||
self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {}
|
||||
self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {}
|
||||
self._mod_tracker = ModTracker()
|
||||
self._sac_metadata: List[_SACMetadata] = []
|
||||
self._sac_mod_metadata: Dict[str, _SACModMetadata] = {}
|
||||
self._leaf_modules: Set[str] = set()
|
||||
self._sac_metadata: list[_SACMetadata] = []
|
||||
self._sac_mod_metadata: dict[str, _SACModMetadata] = {}
|
||||
self._leaf_modules: set[str] = set()
|
||||
self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks(
|
||||
self._pack_hook, lambda x: x
|
||||
)
|
||||
self._saved_tensor_ids: Set[int] = set()
|
||||
self._saved_tensor_ids: set[int] = set()
|
||||
self._estimate_runtime = RuntimeEstimator._roofline_estimate
|
||||
|
||||
def _pack_hook(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -313,7 +313,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
return all(not isinstance(x, torch.Tensor) for x in flat_inputs)
|
||||
|
||||
def _get_sac_stats(
|
||||
self, data: List[_SACMetadata], force_store_random: bool
|
||||
self, data: list[_SACMetadata], force_store_random: bool
|
||||
) -> SACStats:
|
||||
# 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd
|
||||
# inserts those during backward and it breaks the fwd-bwd alignment
|
||||
|
|
@ -381,12 +381,12 @@ class SACEstimator(TorchDispatchMode):
|
|||
)
|
||||
|
||||
def _get_inplace_metadata(
|
||||
self, func: Any, out_storages: Set[UntypedStorage]
|
||||
) -> tuple[int, tuple[int, ...], Dict[str, tuple[int, ...]]]:
|
||||
self, func: Any, out_storages: set[UntypedStorage]
|
||||
) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]:
|
||||
# 1. Get the current index of the metadata obtained so far
|
||||
curr_idx = len(self._sac_metadata)
|
||||
# 2. Get the set of active modules that are not leaf
|
||||
active_mod_fqns: Set[str] = {
|
||||
active_mod_fqns: set[str] = {
|
||||
par for par in self._mod_tracker.parents if par not in self._leaf_modules
|
||||
}
|
||||
# 3. Output ids are the identifies of the storage objects corresponding to the tensors
|
||||
|
|
@ -397,7 +397,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
|
||||
op_idx = curr_idx
|
||||
# 5. Initialize the parent op ids of the inplace op for each of the active modules
|
||||
mod_op_parent_idxs: Dict[str, int] = {
|
||||
mod_op_parent_idxs: dict[str, int] = {
|
||||
mod_fqn: -1 for mod_fqn in active_mod_fqns
|
||||
}
|
||||
for i, d in enumerate(self._sac_metadata):
|
||||
|
|
@ -430,9 +430,9 @@ class SACEstimator(TorchDispatchMode):
|
|||
# 1. Get the runtime estimate
|
||||
out, op_time = self._estimate_runtime(func, args, kwargs)
|
||||
flat_outs, _ = tree_flatten(out)
|
||||
out_storages_cuda: Set[UntypedStorage] = set()
|
||||
out_storages_cpu: Set[UntypedStorage] = set()
|
||||
cuda_devices: Set[torch.device] = set()
|
||||
out_storages_cuda: set[UntypedStorage] = set()
|
||||
out_storages_cpu: set[UntypedStorage] = set()
|
||||
cuda_devices: set[torch.device] = set()
|
||||
for o in flat_outs:
|
||||
if isinstance(o, torch.Tensor):
|
||||
if o.device.type == "cuda":
|
||||
|
|
@ -496,8 +496,8 @@ class SACEstimator(TorchDispatchMode):
|
|||
# 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group
|
||||
# The top-most op can itself be an inplace-op or can be a non-inplace op.
|
||||
# 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads.
|
||||
inplace_op_groups: Dict[int, Set[int]] = {}
|
||||
inplace_op_to_group_head: Dict[int, int] = dict(sac_stats.inplace_ops)
|
||||
inplace_op_groups: dict[int, set[int]] = {}
|
||||
inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops)
|
||||
|
||||
# Initialize inplace_op_groups using inplace_op_to_group_head
|
||||
for op_idx, group_head_idx in inplace_op_to_group_head.items():
|
||||
|
|
@ -508,7 +508,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
# as a group. This is because, they affect the ranom seed generator. If force_store_random is set True,
|
||||
# all of the random ops will be stored by default. For easy of manageability, we store the top-most random op
|
||||
# as the leader of the random_ops_group.
|
||||
random_ops_group: Dict[int, Set[int]] = {}
|
||||
random_ops_group: dict[int, set[int]] = {}
|
||||
random_group_head_idx = min(sac_stats.rand_ops, default=-1)
|
||||
has_rand_ops = bool(sac_stats.rand_ops)
|
||||
if has_rand_ops:
|
||||
|
|
@ -521,8 +521,8 @@ class SACEstimator(TorchDispatchMode):
|
|||
# b) If any op in the group is random and force_store_random is set, then entire group will be stored.
|
||||
# c) If none of ops in the group are random and the head of the group is not an in-place op, then
|
||||
# this group can be considered for recomputation in its entireity
|
||||
stored_ops: Set[int] = set()
|
||||
recomputed_ops: Set[int] = set()
|
||||
stored_ops: set[int] = set()
|
||||
recomputed_ops: set[int] = set()
|
||||
# Case 1:
|
||||
if has_rand_ops and sac_stats.force_store_random:
|
||||
stored_ops.add(random_group_head_idx)
|
||||
|
|
@ -541,7 +541,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
stored_ops.add(group_head_idx)
|
||||
|
||||
# The potential recompute candidates are populated as:
|
||||
recompute_candidates: Set[int] = set()
|
||||
recompute_candidates: set[int] = set()
|
||||
# 1) The random group head if it is not stored
|
||||
if has_rand_ops and random_group_head_idx not in stored_ops:
|
||||
recompute_candidates.add(random_group_head_idx)
|
||||
|
|
@ -557,7 +557,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
)
|
||||
|
||||
# We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second
|
||||
msps_meta: List[MSPS] = []
|
||||
msps_meta: list[MSPS] = []
|
||||
for cand_idx in recompute_candidates:
|
||||
op_indices = {cand_idx}
|
||||
if cand_idx in inplace_op_groups:
|
||||
|
|
@ -598,7 +598,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
greedy_order_meta.msps_meta,
|
||||
)
|
||||
# 1. Intitialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops
|
||||
recomp_indices: Set[int] = set()
|
||||
recomp_indices: set[int] = set()
|
||||
for r_idx in recomputed_ops:
|
||||
recomp_indices.add(r_idx)
|
||||
if r_idx in inplace_op_groups:
|
||||
|
|
@ -628,7 +628,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
recomp_runtime / sac_runtime
|
||||
)
|
||||
# 6. Finally, we add the memory and recomputation time of the always stored ops.
|
||||
stored_indices: Set[int] = set()
|
||||
stored_indices: set[int] = set()
|
||||
for s_idx in stored_ops:
|
||||
stored_indices.add(s_idx)
|
||||
if s_idx in inplace_op_groups:
|
||||
|
|
@ -653,7 +653,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
|
||||
# save prediction graph
|
||||
def save_prediction_graph(
|
||||
pwlf_: pwlf.PiecewiseLinFit, x: List[float], y: List[float], filename: str
|
||||
pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str
|
||||
) -> None:
|
||||
try:
|
||||
import matplotlib.pyplot as plt # type: ignore[import-not-found]
|
||||
|
|
@ -811,8 +811,8 @@ class SACEstimator(TorchDispatchMode):
|
|||
recomp_runtime: float = 0.0
|
||||
|
||||
def append_row(
|
||||
op_indices: Set[int],
|
||||
func_names: Set[str],
|
||||
op_indices: set[int],
|
||||
func_names: set[str],
|
||||
msps: Optional[float] = None,
|
||||
stored: Optional[bool] = False,
|
||||
recomputed: Optional[bool] = False,
|
||||
|
|
@ -839,7 +839,7 @@ class SACEstimator(TorchDispatchMode):
|
|||
)
|
||||
|
||||
for op_idx in recomputed_ops:
|
||||
op_indices: Set[int] = {op_idx}
|
||||
op_indices: set[int] = {op_idx}
|
||||
if op_idx in inplace_op_groups:
|
||||
op_indices.update(inplace_op_groups[op_idx])
|
||||
if op_idx in random_ops_group:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
import math
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from torch.distributed._tools.ilp_utils import Graph, is_submodule
|
||||
from torch.distributed._tools.sac_estimator import SACStats
|
||||
|
|
@ -36,9 +36,9 @@ def sac_milp(
|
|||
graph: Graph,
|
||||
memory_budget: float,
|
||||
world_size: int = 1,
|
||||
ac_units: Optional[List[str]] = None,
|
||||
fsdp_units: Optional[List[str]] = None,
|
||||
) -> tuple[Dict[str, float], float, int]:
|
||||
ac_units: Optional[list[str]] = None,
|
||||
fsdp_units: Optional[list[str]] = None,
|
||||
) -> tuple[dict[str, float], float, int]:
|
||||
"""
|
||||
MILP to decide which modules to AC and how much memory to discard.
|
||||
The objective is to minimize recomputation time.
|
||||
|
|
@ -224,7 +224,7 @@ class SACDecision(IntEnum):
|
|||
|
||||
def get_optimal_checkpointing_policy_per_module(
|
||||
sac_stats: SACStats, memory_budget: float
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
"""
|
||||
This is adapted from --
|
||||
https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from enum import auto, Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Iterator, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -69,10 +70,10 @@ class ActivationWrapper(torch.nn.Module, ABC):
|
|||
@staticmethod
|
||||
def _post_state_dict_hook(
|
||||
module: nn.Module,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
*args: Any,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
_post_state_dict_hook() is called after the state_dict() of this FSDP module is executed.
|
||||
|
||||
|
|
@ -87,7 +88,7 @@ class ActivationWrapper(torch.nn.Module, ABC):
|
|||
@staticmethod
|
||||
def _pre_load_state_dict_hook(
|
||||
module: nn.Module,
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
prefix: str,
|
||||
*args: Any,
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Type
|
||||
|
||||
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
|
||||
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
|
||||
|
|
@ -15,7 +14,7 @@ from torch.optim import Optimizer
|
|||
|
||||
|
||||
# Contains the mappings between the regular and overlapped optimizer types.
|
||||
_registered_overlapped_optims: Dict[Type, Type] = {}
|
||||
_registered_overlapped_optims: dict[type, type] = {}
|
||||
|
||||
|
||||
def register_overlapped(optim_cls):
|
||||
|
|
@ -33,7 +32,7 @@ def register_overlapped(optim_cls):
|
|||
|
||||
|
||||
class OverlappedOptimizer(ABC):
|
||||
def __init__(self, optim_cls: Type) -> None:
|
||||
def __init__(self, optim_cls: type) -> None:
|
||||
"""
|
||||
Initialize the OverlappedOptimizer.
|
||||
|
||||
|
|
@ -61,7 +60,7 @@ class OverlappedOptimizer(ABC):
|
|||
class _OverlappedStandardOptimizer(OverlappedOptimizer):
|
||||
"""Overlaps a regular ``Optimizer``."""
|
||||
|
||||
def __init__(self, optim_cls: Type, params, *optim_args, **optim_kwargs) -> None:
|
||||
def __init__(self, optim_cls: type, params, *optim_args, **optim_kwargs) -> None:
|
||||
super().__init__(optim_cls)
|
||||
f_optim = as_functional_optim(self.optim_cls, *optim_args, **optim_kwargs)
|
||||
self._opt_hook_state = _OptimizerHookState(f_optim, params)
|
||||
|
|
@ -82,7 +81,7 @@ class _OverlappedStandardOptimizer(OverlappedOptimizer):
|
|||
)
|
||||
|
||||
|
||||
def _as_overlapped_optim(optim_cls: Type, params, *args, **kwargs):
|
||||
def _as_overlapped_optim(optim_cls: type, params, *args, **kwargs):
|
||||
"""Return a new ``OverlappedOptimizer`` instance that supports ``optim_cls``."""
|
||||
for clz in inspect.getmro(optim_cls):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import weakref
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -46,7 +46,7 @@ def _perform_local_step(
|
|||
# expects `None` in a list position to indicate that the corresponding
|
||||
# parameter should not be updated
|
||||
num_local_optim_params = len(zero.optim.param_groups[0]["params"])
|
||||
gradients: List[Optional[torch.Tensor]] = [
|
||||
gradients: list[Optional[torch.Tensor]] = [
|
||||
_NO_PARAM_UPDATE for _ in range(num_local_optim_params)
|
||||
]
|
||||
assert (
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, List, no_type_check
|
||||
from typing import Any, Callable, no_type_check
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
_FUNCTIONAL_OPTIM_STEP_METHOD_NAME = "step_param"
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -252,9 +251,9 @@ class PowerSGDState:
|
|||
self.rng = np.random.RandomState(random_seed)
|
||||
# Since there is only a single state instance for all the input buckets,
|
||||
# need to maintain a dictionary that maps each bucket index to the local error.
|
||||
self.error_dict: Dict[int, torch.Tensor] = {}
|
||||
self.p_memory_dict: Dict[int, torch.Tensor] = {}
|
||||
self.q_memory_dict: Dict[int, torch.Tensor] = {}
|
||||
self.error_dict: dict[int, torch.Tensor] = {}
|
||||
self.p_memory_dict: dict[int, torch.Tensor] = {}
|
||||
self.q_memory_dict: dict[int, torch.Tensor] = {}
|
||||
# Iteration/step in the training loop.
|
||||
self.iter = 0
|
||||
# Compression stats accumulators
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, List, NamedTuple, Optional, Type
|
||||
from typing import Any, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -165,7 +165,7 @@ class Join:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
joinables: List[Joinable],
|
||||
joinables: list[Joinable],
|
||||
enable: bool = True,
|
||||
throw_on_early_termination: bool = False,
|
||||
**kwargs,
|
||||
|
|
@ -228,7 +228,7 @@ class Join:
|
|||
|
||||
def __exit__(
|
||||
self,
|
||||
type: Optional[Type[BaseException]],
|
||||
type: Optional[type[BaseException]],
|
||||
value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Iterable, Optional, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -107,7 +108,7 @@ class PeriodicModelAverager(ModelAverager):
|
|||
def average_parameters(
|
||||
self,
|
||||
params: Union[
|
||||
Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
|
||||
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@
|
|||
import logging
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Iterable, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -159,7 +160,7 @@ class HierarchicalModelAverager(averagers.ModelAverager):
|
|||
def average_parameters(
|
||||
self,
|
||||
params: Union[
|
||||
Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
|
||||
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
|
||||
],
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# flake8: noqa C101
|
||||
import itertools
|
||||
from typing import Dict, Iterable, Iterator, Union
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -51,7 +52,7 @@ def average_parameters(
|
|||
|
||||
|
||||
def get_params_to_average(
|
||||
params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]
|
||||
params: Union[Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]]
|
||||
):
|
||||
"""
|
||||
Return a list of parameters that need to average.
|
||||
|
|
@ -81,7 +82,7 @@ def get_params_to_average(
|
|||
|
||||
def average_parameters_or_parameter_groups(
|
||||
params: Union[
|
||||
Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]
|
||||
Iterable[torch.nn.Parameter], Iterable[dict[str, torch.nn.Parameter]]
|
||||
],
|
||||
process_group: ProcessGroup,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import functools
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, TypeVar
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
|
@ -18,7 +18,7 @@ from torch.distributed.logging_handlers import _log_handlers
|
|||
from torch.monitor import _WaitCounter
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
_DEFAULT_DESTINATION = "default"
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ global _c10d_logger
|
|||
_c10d_logger = _get_or_create_logger()
|
||||
|
||||
|
||||
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
|
||||
def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
|
||||
if dist.is_initialized():
|
||||
group = kwargs.get("group") or kwargs.get("process_group")
|
||||
msg_dict = {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ Each should also handle single rank scenario.
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, cast, Generic, List, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ def all_gather(
|
|||
data_or_fn: Union[T, Callable[[], T]],
|
||||
stage_name: Optional[str] = None,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
) -> List[T]:
|
||||
) -> list[T]:
|
||||
"""
|
||||
A simple all_gather primitive with basic synchronization guard logic,
|
||||
by checking payload from all ranks has the same stage name.
|
||||
|
|
@ -149,11 +149,11 @@ def all_gather(
|
|||
all_gather_object_enforce_type(pg, total_list, sync_obj)
|
||||
# Each rank will throw RuntimeError in case of failure on any rank.
|
||||
stage_name = cast(SyncPayload[T], total_list[0]).stage_name
|
||||
exception_list: List[tuple[int, Exception]] = []
|
||||
ret_list: List[T] = []
|
||||
exception_list: list[tuple[int, Exception]] = []
|
||||
ret_list: list[T] = []
|
||||
error_msg: str = ""
|
||||
|
||||
for i, sp in enumerate(cast(List[SyncPayload[T]], total_list)):
|
||||
for i, sp in enumerate(cast(list[SyncPayload[T]], total_list)):
|
||||
if sp.stage_name != stage_name:
|
||||
error_msg += (
|
||||
f"Unexpected stage name received from rank {i}: {sp.stage_name} "
|
||||
|
|
@ -183,7 +183,7 @@ def all_gather(
|
|||
def all_gather_object_enforce_type(
|
||||
pg: dist.ProcessGroup,
|
||||
# pyre-fixme[2]: Parameter must have a type that does not contain `Any`
|
||||
object_list: List[Any],
|
||||
object_list: list[Any],
|
||||
# pyre-fixme[2]: Parameter must have a type other than `Any`
|
||||
obj: Any,
|
||||
# pyre-fixme[2]: Parameter must have a type that does not contain `Any`
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import math
|
|||
import threading
|
||||
from functools import reduce
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import is_available
|
||||
|
|
@ -65,15 +65,15 @@ else:
|
|||
|
||||
class _MeshEnv(threading.local):
|
||||
def __init__(self) -> None:
|
||||
self.mesh_stack: List[DeviceMesh] = []
|
||||
self.child_to_root_mapping: Dict[DeviceMesh, DeviceMesh] = {}
|
||||
self.mesh_dim_group_options: Dict[
|
||||
self.mesh_stack: list[DeviceMesh] = []
|
||||
self.child_to_root_mapping: dict[DeviceMesh, DeviceMesh] = {}
|
||||
self.mesh_dim_group_options: dict[
|
||||
int, tuple[str, Optional[C10dBackend.Options]]
|
||||
] = {}
|
||||
self.root_to_flatten_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
|
||||
self.root_to_flatten_mapping: dict[DeviceMesh, dict[str, DeviceMesh]] = {}
|
||||
# Record flatten mesh name to its mesh dim index in root mesh.
|
||||
self.flatten_name_to_root_dims: Dict[
|
||||
DeviceMesh, Dict[str, tuple[int, ...]]
|
||||
self.flatten_name_to_root_dims: dict[
|
||||
DeviceMesh, dict[str, tuple[int, ...]]
|
||||
] = {}
|
||||
|
||||
def get_current_mesh(self) -> "DeviceMesh":
|
||||
|
|
@ -85,7 +85,7 @@ else:
|
|||
self,
|
||||
device_mesh: "DeviceMesh",
|
||||
submesh_dim_names: tuple[str, ...],
|
||||
submesh_dims: List[tuple[int, ...]],
|
||||
submesh_dims: list[tuple[int, ...]],
|
||||
) -> "DeviceMesh":
|
||||
# Get the submesh dim size from the submesh_dims.
|
||||
# For example, if we have a 3D mesh with mesh_shape (2, 2, 2) mesh_dim_names ("dp", "cp", "tp") and we want
|
||||
|
|
@ -286,7 +286,7 @@ else:
|
|||
|
||||
def _get_slice_mesh_dims(
|
||||
self, device_mesh, mesh_dim_names
|
||||
) -> List[tuple[int, ...]]:
|
||||
) -> list[tuple[int, ...]]:
|
||||
"""
|
||||
Validate whether the mesh_dim_names is valid for slicing the given device_mesh.
|
||||
If valid, return dim indexes of the slice mesh in the device mesh.
|
||||
|
|
@ -338,7 +338,7 @@ else:
|
|||
|
||||
def _get_all_submeshes(
|
||||
self, device_mesh: "DeviceMesh", mesh_dim_name: str
|
||||
) -> List["DeviceMesh"]:
|
||||
) -> list["DeviceMesh"]:
|
||||
"""
|
||||
Return all the submeshes of a given mesh dimension of the device mesh.
|
||||
"""
|
||||
|
|
@ -458,7 +458,7 @@ else:
|
|||
# calculate the coordinates of the current global rank on the mesh
|
||||
rank_coords = (self.mesh == get_rank()).nonzero()
|
||||
assert rank_coords.size(0) in (0, 1)
|
||||
self._coordinate_on_dim: Optional[List[int]] = (
|
||||
self._coordinate_on_dim: Optional[list[int]] = (
|
||||
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
|
||||
)
|
||||
|
||||
|
|
@ -498,7 +498,7 @@ else:
|
|||
# TODO(yifu): remove tag and ranks once we fully migrate to native
|
||||
# functional collectives. See details in:
|
||||
# https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
|
||||
dim_group_infos: List[tuple[str, List[int], str]] = []
|
||||
dim_group_infos: list[tuple[str, list[int], str]] = []
|
||||
default_group = _get_default_group()
|
||||
|
||||
if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
|
||||
|
|
@ -775,7 +775,7 @@ else:
|
|||
_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2]) # type: ignore[index]
|
||||
)
|
||||
|
||||
def get_all_groups(self) -> List[ProcessGroup]:
|
||||
def get_all_groups(self) -> list[ProcessGroup]:
|
||||
"""
|
||||
Returns a list of ProcessGroups for all mesh dimensions.
|
||||
|
||||
|
|
@ -786,7 +786,7 @@ else:
|
|||
|
||||
@staticmethod
|
||||
def from_group(
|
||||
group: Union[ProcessGroup, List[ProcessGroup]],
|
||||
group: Union[ProcessGroup, list[ProcessGroup]],
|
||||
device_type: str,
|
||||
mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None,
|
||||
*,
|
||||
|
|
@ -910,7 +910,7 @@ else:
|
|||
), "We expect ProcessGroup before calling `get_rank`!"
|
||||
return not_none(get_rank(mesh_dim_group))
|
||||
|
||||
def get_coordinate(self) -> Optional[List[int]]:
|
||||
def get_coordinate(self) -> Optional[list[int]]:
|
||||
"""
|
||||
Return the relative indices of this rank relative to all
|
||||
dimensions of the mesh. If this rank is not part of the mesh, return None.
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import time
|
|||
import warnings
|
||||
from collections import namedtuple
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
|
@ -259,18 +259,18 @@ class Backend(str):
|
|||
|
||||
_BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
|
||||
|
||||
_plugins: Dict[str, _BackendPlugin] = {}
|
||||
_plugins: dict[str, _BackendPlugin] = {}
|
||||
|
||||
backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]
|
||||
|
||||
# 3rd-party devices can register the default backend support here
|
||||
default_device_backend_map: Dict[str, str] = {
|
||||
default_device_backend_map: dict[str, str] = {
|
||||
"cpu": GLOO,
|
||||
"cuda": NCCL,
|
||||
"xpu": XCCL,
|
||||
}
|
||||
|
||||
backend_capability: Dict[str, List[str]] = {
|
||||
backend_capability: dict[str, list[str]] = {
|
||||
GLOO: ["cpu", "cuda"],
|
||||
NCCL: ["cuda"],
|
||||
XCCL: ["xpu"],
|
||||
|
|
@ -278,7 +278,7 @@ class Backend(str):
|
|||
MPI: ["cpu", "cuda"],
|
||||
}
|
||||
|
||||
backend_type_map: Dict[str, ProcessGroup.BackendType] = {
|
||||
backend_type_map: dict[str, ProcessGroup.BackendType] = {
|
||||
UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
|
||||
GLOO: ProcessGroup.BackendType.GLOO,
|
||||
NCCL: ProcessGroup.BackendType.NCCL,
|
||||
|
|
@ -303,7 +303,7 @@ class Backend(str):
|
|||
name,
|
||||
func,
|
||||
extended_api=False,
|
||||
devices: Optional[Union[str, List[str]]] = None,
|
||||
devices: Optional[Union[str, list[str]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new backend with the given name and instantiating function.
|
||||
|
|
@ -371,7 +371,7 @@ class BackendConfig:
|
|||
|
||||
def __init__(self, backend: Backend):
|
||||
"""Init."""
|
||||
self.device_backend_map: Dict[str, Backend] = {}
|
||||
self.device_backend_map: dict[str, Backend] = {}
|
||||
backend = str(backend)
|
||||
|
||||
if backend == Backend.UNDEFINED:
|
||||
|
|
@ -441,7 +441,7 @@ class BackendConfig:
|
|||
f"{device}:{backend}" for device, backend in self.device_backend_map.items()
|
||||
)
|
||||
|
||||
def get_device_backend_map(self) -> Dict[str, Backend]:
|
||||
def get_device_backend_map(self) -> dict[str, Backend]:
|
||||
"""Return backend map of the device."""
|
||||
return self.device_backend_map
|
||||
|
||||
|
|
@ -572,14 +572,14 @@ class _CollOp:
|
|||
|
||||
# DO NOT USE THESE FIELDS DIRECTLY.
|
||||
# Use them through the _world object to make sure the _world override mechanism
|
||||
_pg_map: Dict[ProcessGroup, tuple[str, Store]] = {}
|
||||
_pg_names: Dict[ProcessGroup, str] = {}
|
||||
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
|
||||
_pg_map: dict[ProcessGroup, tuple[str, Store]] = {}
|
||||
_pg_names: dict[ProcessGroup, str] = {}
|
||||
_pg_group_ranks: dict[ProcessGroup, dict[int, int]] = {}
|
||||
# For a pg, it is a map from ProcessGroup to BackendConfig
|
||||
_pg_backend_config: Dict[ProcessGroup, str] = {}
|
||||
_pg_backend_config: dict[ProcessGroup, str] = {}
|
||||
_group_count = 0
|
||||
_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
|
||||
_pg_to_tag: Dict[ProcessGroup, str] = {}
|
||||
_tags_to_pg: dict[str, list[ProcessGroup]] = {}
|
||||
_pg_to_tag: dict[ProcessGroup, str] = {}
|
||||
_backend: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -595,7 +595,7 @@ class _World:
|
|||
|
||||
def __init__(self) -> None:
|
||||
self._default_pg = None
|
||||
self._pg_coalesce_state: Dict[ProcessGroup, List[_CollOp]] = {}
|
||||
self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {}
|
||||
|
||||
@property
|
||||
def default_pg(self) -> Optional[ProcessGroup]:
|
||||
|
|
@ -612,7 +612,7 @@ class _World:
|
|||
self._default_pg = value
|
||||
|
||||
@property
|
||||
def pg_map(self) -> Dict[ProcessGroup, tuple[str, Store]]:
|
||||
def pg_map(self) -> dict[ProcessGroup, tuple[str, Store]]:
|
||||
"""
|
||||
Provide Mapping from ProcessGroup to backend name and store.
|
||||
|
||||
|
|
@ -625,7 +625,7 @@ class _World:
|
|||
return _pg_map
|
||||
|
||||
@property
|
||||
def pg_names(self) -> Dict[ProcessGroup, str]:
|
||||
def pg_names(self) -> dict[ProcessGroup, str]:
|
||||
"""
|
||||
Process group's names, map from ProcessGroup to str.
|
||||
|
||||
|
|
@ -635,7 +635,7 @@ class _World:
|
|||
return _pg_names
|
||||
|
||||
@property
|
||||
def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
|
||||
def pg_group_ranks(self) -> dict[ProcessGroup, dict[int, int]]:
|
||||
"""
|
||||
Process group's global rank to local rank mapping.
|
||||
|
||||
|
|
@ -645,7 +645,7 @@ class _World:
|
|||
return _pg_group_ranks
|
||||
|
||||
@property
|
||||
def pg_backend_config(self) -> Dict[ProcessGroup, str]:
|
||||
def pg_backend_config(self) -> dict[ProcessGroup, str]:
|
||||
"""
|
||||
Process group's backend config.
|
||||
|
||||
|
|
@ -671,27 +671,27 @@ class _World:
|
|||
_group_count = value
|
||||
|
||||
@property
|
||||
def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
|
||||
def tags_to_pg(self) -> dict[str, list[ProcessGroup]]:
|
||||
global _tags_to_pg
|
||||
return _tags_to_pg
|
||||
|
||||
@property
|
||||
def pg_to_tag(self) -> Dict[ProcessGroup, str]:
|
||||
def pg_to_tag(self) -> dict[ProcessGroup, str]:
|
||||
global _pg_to_tag
|
||||
return _pg_to_tag
|
||||
|
||||
@property
|
||||
def pg_coalesce_state(self) -> Dict[ProcessGroup, List[_CollOp]]:
|
||||
def pg_coalesce_state(self) -> dict[ProcessGroup, list[_CollOp]]:
|
||||
return self._pg_coalesce_state
|
||||
|
||||
@property
|
||||
def pg_config_info(self) -> List[Dict[str, Any]]:
|
||||
def pg_config_info(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Return a list of dict with process groups and backends.
|
||||
|
||||
Along with their unique IDs and configurations (types and ranks).
|
||||
"""
|
||||
config_info: List[Dict[str, Any]] = []
|
||||
config_info: list[dict[str, Any]] = []
|
||||
default_pg_size = _get_group_size(None)
|
||||
for pg in self.pg_map.keys():
|
||||
ranks = self.pg_group_ranks[pg]
|
||||
|
|
@ -912,7 +912,7 @@ def _get_pg_default_device(group: Optional[ProcessGroup] = None) -> torch.device
|
|||
return rv
|
||||
|
||||
|
||||
def _device_capability(group: Optional[ProcessGroup] = None) -> List[str]:
|
||||
def _device_capability(group: Optional[ProcessGroup] = None) -> list[str]:
|
||||
"""
|
||||
Return the device type(s) supported by ``group``.
|
||||
|
||||
|
|
@ -1077,7 +1077,7 @@ def _get_global_rank(group, rank) -> int:
|
|||
return get_global_rank(group, rank)
|
||||
|
||||
|
||||
def get_process_group_ranks(group: ProcessGroup) -> List[int]:
|
||||
def get_process_group_ranks(group: ProcessGroup) -> list[int]:
|
||||
"""
|
||||
Get all ranks associated with ``group``.
|
||||
|
||||
|
|
@ -1103,7 +1103,7 @@ def _get_group_size_by_name(group_name: str) -> int:
|
|||
return group.size()
|
||||
|
||||
|
||||
def _resolve_group_name_by_ranks_and_tag(ranks: List[int], tag: str) -> str:
|
||||
def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> str:
|
||||
# TODO(yifu): remove this function once ranks + tag is not a supported
|
||||
# identifier for process group for functional collectives.
|
||||
group = _find_pg_by_ranks_and_tag(tag, ranks)
|
||||
|
|
@ -1401,7 +1401,7 @@ def _get_process_group_uid(pg: ProcessGroup) -> int:
|
|||
return -1
|
||||
|
||||
|
||||
def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]:
|
||||
def _get_pg_config(group: Optional[ProcessGroup] = None) -> dict[str, Any]:
|
||||
"""
|
||||
Return the pg configuration of the given process group.
|
||||
|
||||
|
|
@ -1416,12 +1416,12 @@ def _get_pg_config(group: Optional[ProcessGroup] = None) -> Dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _get_all_pg_configs() -> List[Dict[str, Any]]:
|
||||
def _get_all_pg_configs() -> list[dict[str, Any]]:
|
||||
"""
|
||||
Return the pg configuration of all the process groups.
|
||||
|
||||
"""
|
||||
config_info: List[Dict[str, Any]] = [
|
||||
config_info: list[dict[str, Any]] = [
|
||||
_get_pg_config(pg) for pg in _world.pg_map.keys()
|
||||
]
|
||||
return config_info
|
||||
|
|
@ -2507,7 +2507,7 @@ class _IllegalWork(Work):
|
|||
|
||||
class _CoalescingManager:
|
||||
def __init__(self) -> None:
|
||||
self.works: List[Work] = []
|
||||
self.works: list[Work] = []
|
||||
|
||||
def append(self, work: Work):
|
||||
if work:
|
||||
|
|
@ -2607,7 +2607,7 @@ def _coalescing_manager(
|
|||
work.wait() # type: ignore[possibly-undefined]
|
||||
|
||||
|
||||
def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
|
||||
def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
|
||||
"""
|
||||
Send or Receive a batch of tensors asynchronously and return a list of requests.
|
||||
|
||||
|
|
@ -2651,7 +2651,7 @@ def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
|
|||
group = p2p_op_list[0].group
|
||||
device = p2p_op_list[0].tensor.device
|
||||
|
||||
def peer_kwarg(op: P2POp) -> Dict[str, int]:
|
||||
def peer_kwarg(op: P2POp) -> dict[str, int]:
|
||||
key = "group_dst" if op.op == isend else "group_src"
|
||||
return {key: op.group_peer}
|
||||
|
||||
|
|
@ -3054,7 +3054,7 @@ def all_gather_object(object_list, obj, group=None):
|
|||
@_exception_logger
|
||||
def gather_object(
|
||||
obj: Any,
|
||||
object_gather_list: Optional[List[Any]] = None,
|
||||
object_gather_list: Optional[list[Any]] = None,
|
||||
dst: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
group_dst: Optional[int] = None,
|
||||
|
|
@ -3178,7 +3178,7 @@ def gather_object(
|
|||
|
||||
@_exception_logger
|
||||
def send_object_list(
|
||||
object_list: List[Any],
|
||||
object_list: list[Any],
|
||||
dst: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
|
|
@ -3277,7 +3277,7 @@ def send_object_list(
|
|||
|
||||
@_exception_logger
|
||||
def recv_object_list(
|
||||
object_list: List[Any],
|
||||
object_list: list[Any],
|
||||
src: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
|
|
@ -3379,7 +3379,7 @@ def recv_object_list(
|
|||
|
||||
@_exception_logger
|
||||
def broadcast_object_list(
|
||||
object_list: List[Any],
|
||||
object_list: list[Any],
|
||||
src: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
|
|
@ -3505,8 +3505,8 @@ def broadcast_object_list(
|
|||
|
||||
@_exception_logger
|
||||
def scatter_object_list(
|
||||
scatter_object_output_list: List[Any],
|
||||
scatter_object_input_list: Optional[List[Any]] = None,
|
||||
scatter_object_output_list: list[Any],
|
||||
scatter_object_input_list: Optional[list[Any]] = None,
|
||||
src: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
group_src: Optional[int] = None,
|
||||
|
|
@ -3933,7 +3933,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
|
|||
@_exception_logger
|
||||
def gather(
|
||||
tensor: torch.Tensor,
|
||||
gather_list: Optional[List[torch.Tensor]] = None,
|
||||
gather_list: Optional[list[torch.Tensor]] = None,
|
||||
dst: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
async_op: bool = False,
|
||||
|
|
@ -4013,7 +4013,7 @@ def gather(
|
|||
@_exception_logger
|
||||
def scatter(
|
||||
tensor: torch.Tensor,
|
||||
scatter_list: Optional[List[torch.Tensor]] = None,
|
||||
scatter_list: Optional[list[torch.Tensor]] = None,
|
||||
src: Optional[int] = None,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
async_op: bool = False,
|
||||
|
|
@ -4657,7 +4657,7 @@ def _create_process_group_wrapper(
|
|||
|
||||
# helper function for deterministically hashing a list of ranks to a unique
|
||||
# string
|
||||
def _hash_ranks_to_str(ranks: List[int]) -> str:
|
||||
def _hash_ranks_to_str(ranks: list[int]) -> str:
|
||||
rank_join: str = "_".join(map(str, ranks))
|
||||
# In case there is already a PG with the same rank composition
|
||||
unique_str = "_".join([rank_join, str(len(_world.pg_names))])
|
||||
|
|
@ -4665,7 +4665,7 @@ def _hash_ranks_to_str(ranks: List[int]) -> str:
|
|||
|
||||
|
||||
# Takes a list of ranks and computes an integer color
|
||||
def _process_group_color(ranks: List[int]) -> int:
|
||||
def _process_group_color(ranks: list[int]) -> int:
|
||||
# Convert list to tuple to make it hashable
|
||||
ranks = tuple(ranks)
|
||||
hash_value = hash(ranks)
|
||||
|
|
@ -5315,7 +5315,7 @@ def new_subgroups_by_enumeration(
|
|||
return cur_subgroup, subgroups
|
||||
|
||||
|
||||
def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]:
|
||||
def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGroup]:
|
||||
if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
|
||||
tag = f"user:{tag}"
|
||||
|
||||
|
|
@ -5331,7 +5331,7 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGro
|
|||
|
||||
|
||||
def _find_or_create_pg_by_ranks_and_tag(
|
||||
tag: str, ranks: List[int], stride: int
|
||||
tag: str, ranks: list[int], stride: int
|
||||
) -> ProcessGroup:
|
||||
assert (
|
||||
len(ranks) % stride == 0
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
import sys
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
|
||||
from torch.distributed.elastic import events, metrics
|
||||
|
|
@ -79,13 +79,13 @@ class LaunchConfig:
|
|||
role: str = "default_role"
|
||||
rdzv_endpoint: str = ""
|
||||
rdzv_backend: str = "etcd"
|
||||
rdzv_configs: Dict[str, Any] = field(default_factory=dict)
|
||||
rdzv_configs: dict[str, Any] = field(default_factory=dict)
|
||||
rdzv_timeout: int = -1
|
||||
max_restarts: int = 3
|
||||
monitor_interval: float = 0.1
|
||||
start_method: str = "spawn"
|
||||
log_line_prefix_template: Optional[str] = None
|
||||
metrics_cfg: Dict[str, str] = field(default_factory=dict)
|
||||
metrics_cfg: dict[str, str] = field(default_factory=dict)
|
||||
local_addr: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
|
|
@ -140,7 +140,7 @@ class elastic_launch:
|
|||
|
||||
|
||||
def _get_entrypoint_name(
|
||||
entrypoint: Union[Callable, str, None], args: List[Any]
|
||||
entrypoint: Union[Callable, str, None], args: list[Any]
|
||||
) -> str:
|
||||
"""Retrieve entrypoint name with the rule:
|
||||
1. If entrypoint is a function, use ``entrypoint.__qualname__``.
|
||||
|
|
@ -183,8 +183,8 @@ def _get_addr_and_port(
|
|||
def launch_agent(
|
||||
config: LaunchConfig,
|
||||
entrypoint: Union[Callable, str, None],
|
||||
args: List[Any],
|
||||
) -> Dict[int, Any]:
|
||||
args: list[Any],
|
||||
) -> dict[int, Any]:
|
||||
if not config.run_id:
|
||||
run_id = str(uuid.uuid4().int)
|
||||
logger.warning("config has no run_id, generated a random run_id: %s", run_id)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,10 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
_log_handlers: Dict[str, logging.Handler] = {
|
||||
_log_handlers: dict[str, logging.Handler] = {
|
||||
"default": logging.NullHandler(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,20 +4,8 @@ import collections
|
|||
import io
|
||||
import sys
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed.rpc as rpc
|
||||
|
|
@ -109,8 +97,8 @@ def _create_module_with_interface(
|
|||
return rpc.RRef(module, module_interface_cls)
|
||||
|
||||
|
||||
def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
|
||||
ret: List[rpc.RRef[Parameter]] = [
|
||||
def _param_rrefs(module_rref, recurse) -> list[rpc.RRef[Parameter]]:
|
||||
ret: list[rpc.RRef[Parameter]] = [
|
||||
rpc.RRef(param) for param in module_rref.local_value().parameters(recurse)
|
||||
]
|
||||
return ret
|
||||
|
|
@ -129,9 +117,9 @@ class _RemoteModule(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
remote_device: str,
|
||||
module_cls: Type[nn.Module],
|
||||
args: Optional[Tuple] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
module_cls: type[nn.Module],
|
||||
args: Optional[tuple] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
_module_interface_cls: Any = None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -282,7 +270,7 @@ class _RemoteModule(nn.Module):
|
|||
self._install_generated_methods()
|
||||
self._check_attribute_picklability()
|
||||
|
||||
def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
|
||||
def remote_parameters(self, recurse: bool = True) -> list[rpc.RRef[Parameter]]:
|
||||
"""
|
||||
Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters.
|
||||
|
||||
|
|
@ -371,8 +359,8 @@ class _RemoteModule(nn.Module):
|
|||
hook: Union[
|
||||
Callable[[T, tuple[Any, ...]], Optional[Any]],
|
||||
Callable[
|
||||
[T, tuple[Any, ...], Dict[str, Any]],
|
||||
Optional[tuple[Any, Dict[str, Any]]],
|
||||
[T, tuple[Any, ...], dict[str, Any]],
|
||||
Optional[tuple[Any, dict[str, Any]]],
|
||||
],
|
||||
],
|
||||
prepend: bool = False,
|
||||
|
|
@ -384,7 +372,7 @@ class _RemoteModule(nn.Module):
|
|||
self,
|
||||
hook: Union[
|
||||
Callable[[T, tuple[Any, ...], Any], Optional[Any]],
|
||||
Callable[[T, tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
|
||||
Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]],
|
||||
],
|
||||
prepend: bool = False,
|
||||
with_kwargs: bool = False,
|
||||
|
|
@ -431,7 +419,7 @@ class _RemoteModule(nn.Module):
|
|||
|
||||
def named_modules(
|
||||
self,
|
||||
memo: Optional[Set[Module]] = None,
|
||||
memo: Optional[set[Module]] = None,
|
||||
prefix: str = "",
|
||||
remove_duplicate: bool = True,
|
||||
):
|
||||
|
|
@ -681,9 +669,9 @@ class RemoteModule(_RemoteModule):
|
|||
def __init__(
|
||||
self,
|
||||
remote_device: str,
|
||||
module_cls: Type[nn.Module],
|
||||
args: Optional[Tuple] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
module_cls: type[nn.Module],
|
||||
args: Optional[tuple] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(remote_device, module_cls, args, kwargs)
|
||||
|
||||
|
|
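For context, the annotations above belong to RemoteModule, which constructs a module on a remote RPC worker. A hedged usage sketch, assuming the RPC framework is already initialized and a worker named "worker1" exists:

    import torch.nn as nn
    from torch.distributed.nn import RemoteModule

    # Instantiate nn.Linear(20, 30) on worker1's CPU; args is Optional[tuple]
    # and kwargs is Optional[dict[str, Any]], matching the updated annotations.
    remote_linear = RemoteModule("worker1/cpu", nn.Linear, args=(20, 30))

    # Returns list[rpc.RRef[Parameter]], as annotated above.
    param_rrefs = remote_linear.remote_parameters()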
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Any, Dict, Iterable, List, no_type_check, Type
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, no_type_check
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter
|
||||
# without changing its lifetime.
|
||||
|
|
@ -15,9 +16,9 @@ param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
|
|||
|
||||
@no_type_check
|
||||
def _apply_optimizer_in_backward(
|
||||
optimizer_class: Type[torch.optim.Optimizer],
|
||||
optimizer_class: type[torch.optim.Optimizer],
|
||||
params: Iterable[torch.nn.Parameter],
|
||||
optimizer_kwargs: Dict[str, Any],
|
||||
optimizer_kwargs: dict[str, Any],
|
||||
register_hook: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -97,7 +98,7 @@ def _apply_optimizer_in_backward(
|
|||
_apply_optimizer_in_backward_to_param(param)
|
||||
|
||||
|
||||
def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Optimizer]:
|
||||
def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Optimizer]:
|
||||
"""
|
||||
Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these
|
||||
optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called
|
||||
|
|
@ -113,7 +114,7 @@ def _get_in_backward_optimizers(module: torch.nn.Module) -> List[torch.optim.Opt
|
|||
_apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {'lr': 0.01})
|
||||
optims = _get_in_backward_optimizers(model)
|
||||
"""
|
||||
optims: List[torch.optim.Optimizer] = []
|
||||
optims: list[torch.optim.Optimizer] = []
|
||||
for param in module.parameters():
|
||||
optims.extend(getattr(param, "_in_backward_optimizers", []))
|
||||
|
||||
|
|
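Building on the docstring example above, a short sketch of how the in-backward optimizers are typically exercised (private helpers; the import path is an assumption based on this file):

    import torch
    from torch.distributed.optim.apply_optimizer_in_backward import (
        _apply_optimizer_in_backward,
        _get_in_backward_optimizers,
    )

    model = torch.nn.Linear(4, 2)
    # Register a per-parameter SGD step that runs inside backward().
    _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01})

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()  # parameters are updated here; no explicit optimizer.step()

    # Inspect (but do not step) the hidden per-parameter optimizers.
    optims = _get_in_backward_optimizers(model)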
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional Adadelta Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalAdadelta:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1.0,
|
||||
rho: float = 0.9,
|
||||
eps: float = 1e-6,
|
||||
|
|
@ -51,9 +51,9 @@ class _FunctionalAdadelta:
|
|||
# param group as it's not a common use case.
|
||||
self.param_group = {"params": params}
|
||||
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
|
|
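The state annotation above is the one TorchScript-sensitive change in these functional optimizers: torch.jit.annotate now receives the builtin dict generic. In eager mode annotate simply passes the value through, so the new spelling can be sanity-checked directly:

    import torch

    # PEP 585 generic passed to torch.jit.annotate, as in the diff above.
    state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
    assert state == {}  # annotate is a pass-through on the value in eager mode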
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional Adagrad Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalAdagrad:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-2,
|
||||
lr_decay: float = 0.0,
|
||||
weight_decay: float = 0.0,
|
||||
|
|
@ -53,7 +53,7 @@ class _FunctionalAdagrad:
|
|||
self.foreach = foreach
|
||||
self.fused = fused
|
||||
self.maximize = maximize
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
if len(params) == 0 and not _allow_empty_param_list:
|
||||
raise ValueError("optimizer got an empty parameter list")
|
||||
|
|
@ -70,12 +70,12 @@ class _FunctionalAdagrad:
|
|||
"step": torch.tensor(0.0),
|
||||
}
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
state_sums = []
|
||||
state_steps: List[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
if len(params) != len(gradients):
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional Adam Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalAdam:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-3,
|
||||
betas: tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
|
|
@ -59,7 +59,7 @@ class _FunctionalAdam:
|
|||
self.maximize = maximize
|
||||
self.foreach = foreach
|
||||
self.fused = fused
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
if len(params) == 0 and not _allow_empty_param_list:
|
||||
raise ValueError("optimizer got an empty parameter list")
|
||||
|
|
@ -78,7 +78,7 @@ class _FunctionalAdam:
|
|||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps: List[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
has_complex = torch.is_complex(param)
|
||||
if grad is not None:
|
||||
params_with_grad.append(param)
|
||||
|
|
@ -128,14 +128,14 @@ class _FunctionalAdam:
|
|||
found_inf=None,
|
||||
)
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps: List[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
has_complex = False
|
||||
|
||||
if len(params) != len(gradients):
|
||||
|
|
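The step(gradients: list[Optional[Tensor]]) signature above expects one (possibly None) gradient per parameter, in parameter order; _ScriptLocalOptimizer later in this diff builds exactly such a list. A hedged sketch of direct use (private class; the import path and constructor call style are assumptions):

    import torch
    from torch.distributed.optim.functional_adam import _FunctionalAdam

    w = torch.randn(3, requires_grad=True)
    opt = _FunctionalAdam([w], 1e-3)  # params, lr

    g = torch.ones_like(w)
    opt.step([g])  # one gradient per parameter; None entries are skipped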
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional Adamax Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalAdamax:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-3,
|
||||
betas: tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
|
|
@ -55,7 +55,7 @@ class _FunctionalAdamax:
|
|||
}
|
||||
self.foreach = foreach
|
||||
self.maximize = maximize
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
if len(params) == 0 and not _allow_empty_param_list:
|
||||
raise ValueError("optimizer got an empty parameter list")
|
||||
|
|
@ -64,13 +64,13 @@ class _FunctionalAdamax:
|
|||
# param group as it's not a common use case.
|
||||
self.param_group = {"params": params}
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_infs = []
|
||||
state_steps: List[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
if len(params) != len(gradients):
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional AdamW Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalAdamW:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-3,
|
||||
betas: tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-8,
|
||||
|
|
@ -59,7 +59,7 @@ class _FunctionalAdamW:
|
|||
self.maximize = maximize
|
||||
self.foreach = foreach
|
||||
self.fused = fused
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
if len(params) == 0 and not _allow_empty_param_list:
|
||||
raise ValueError("optimizer got an empty parameter list")
|
||||
|
|
@ -74,7 +74,7 @@ class _FunctionalAdamW:
|
|||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps: List[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
has_complex = torch.is_complex(param)
|
||||
if grad is not None:
|
||||
params_with_grad.append(param)
|
||||
|
|
@ -129,14 +129,14 @@ class _FunctionalAdamW:
|
|||
has_complex=has_complex,
|
||||
)
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps: List[Tensor] = []
|
||||
state_steps: list[Tensor] = []
|
||||
|
||||
if len(params) != len(gradients):
|
||||
raise ValueError(
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional RMSprop Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalRMSprop:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-2,
|
||||
alpha: float = 0.99,
|
||||
eps: float = 1e-8,
|
||||
|
|
@ -55,9 +55,9 @@ class _FunctionalRMSprop:
|
|||
# param group as it's not a common use case.
|
||||
self.param_group = {"params": params}
|
||||
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional Rprop Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalRprop:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-2,
|
||||
etas: tuple[float, float] = (0.5, 1.2),
|
||||
step_sizes: tuple[float, float] = (1e-6, 50),
|
||||
|
|
@ -49,9 +49,9 @@ class _FunctionalRprop:
|
|||
# param group as it's not a common use case.
|
||||
self.param_group = {"params": params}
|
||||
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.optim._functional as F
|
||||
|
|
@ -9,7 +9,7 @@ from torch.distributed.optim._deprecation_warning import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
# Define a TorchScript compatible Functional SGD Optimizer
|
||||
|
|
@ -25,7 +25,7 @@ __all__: List[str] = []
|
|||
class _FunctionalSGD:
|
||||
def __init__(
|
||||
self,
|
||||
params: List[Tensor],
|
||||
params: list[Tensor],
|
||||
lr: float = 1e-2,
|
||||
momentum: float = 0.0,
|
||||
dampening: float = 0.0,
|
||||
|
|
@ -47,7 +47,7 @@ class _FunctionalSGD:
|
|||
self.maximize = maximize
|
||||
self.foreach = foreach
|
||||
self.fused = fused
|
||||
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
|
||||
self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
|
||||
|
||||
if len(params) == 0 and not _allow_empty_param_list:
|
||||
raise ValueError("optimizer got an empty parameter list")
|
||||
|
|
@ -67,7 +67,7 @@ class _FunctionalSGD:
|
|||
dampening = self.defaults["dampening"]
|
||||
lr = self.defaults["lr"]
|
||||
params = [param]
|
||||
momentum_buffer_list: List[Optional[Tensor]] = []
|
||||
momentum_buffer_list: list[Optional[Tensor]] = []
|
||||
grads = []
|
||||
|
||||
has_sparse_grad = False
|
||||
|
|
@ -106,11 +106,11 @@ class _FunctionalSGD:
|
|||
if momentum_buffer is not None:
|
||||
state["momentum_buffer"] = momentum_buffer
|
||||
|
||||
def step(self, gradients: List[Optional[Tensor]]):
|
||||
def step(self, gradients: list[Optional[Tensor]]):
|
||||
params = self.param_group["params"]
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
momentum_buffer_list: List[Optional[Tensor]] = []
|
||||
momentum_buffer_list: list[Optional[Tensor]] = []
|
||||
lr = self.defaults["lr"]
|
||||
weight_decay = self.defaults["weight_decay"]
|
||||
momentum = self.defaults["momentum"]
|
||||
|
|
|
|||
|
|
@ -1,18 +1,9 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import Collection, Mapping
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
overload,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, overload, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -21,7 +12,7 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor
|
|||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -129,7 +120,7 @@ class _NamedOptimizer(optim.Optimizer):
|
|||
)
|
||||
param_group["params"] = params
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Return the ``state_dict`` of the optimizer.
|
||||
|
||||
|
|
@ -317,7 +308,7 @@ class _NamedOptimizer(optim.Optimizer):
|
|||
# Calling ``step`` will load the initial state for optimizer states.
|
||||
self.step(closure=None)
|
||||
|
||||
def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
|
||||
def _pre_load_state_dict(self, state_dict) -> dict[str, Any]:
|
||||
# TODO(chienchin): This API should be FSDP agnostic and should support
|
||||
# general user hooks.
|
||||
if isinstance(self.module, FSDP):
|
||||
|
|
@ -326,7 +317,7 @@ class _NamedOptimizer(optim.Optimizer):
|
|||
)
|
||||
return state_dict
|
||||
|
||||
def _post_state_dict(self, state_dict) -> Dict[str, Any]:
|
||||
def _post_state_dict(self, state_dict) -> dict[str, Any]:
|
||||
# TODO(chienchin): This API should be FSDP agnostic and should support
|
||||
# general user hooks.
|
||||
if isinstance(self.module, FSDP):
|
||||
|
|
@ -334,6 +325,6 @@ class _NamedOptimizer(optim.Optimizer):
|
|||
return state_dict
|
||||
|
||||
|
||||
def _gen_param_group_key(param_keys: List[str]) -> str:
|
||||
def _gen_param_group_key(param_keys: list[str]) -> str:
|
||||
"""Concatenate all param keys as a unique indentifier for one param group."""
|
||||
return "/".join(sorted(param_keys))
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed.autograd as dist_autograd
|
||||
|
|
@ -52,7 +52,7 @@ class _ScriptLocalOptimizer(nn.Module):
|
|||
def step(self, autograd_ctx_id: int):
|
||||
all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
|
||||
# apply functional optimizer step with a list of gradients
|
||||
grads: List[Optional[Tensor]] = [
|
||||
grads: list[Optional[Tensor]] = [
|
||||
all_local_grads[p] if p in all_local_grads else None
|
||||
for p in self._local_params
|
||||
]
|
||||
|
|
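The gradient list built above comes from a distributed autograd context, which is how DistributedOptimizer drives these TorchScript functional optimizers. A hedged end-to-end sketch, where params_rref (a list of RRefs to remote parameters) and forward_pass are placeholders and RPC is assumed to be initialized:

    import torch
    import torch.distributed.autograd as dist_autograd
    from torch.distributed.optim import DistributedOptimizer

    dist_optim = DistributedOptimizer(torch.optim.SGD, params_rref, lr=0.05)

    with dist_autograd.context() as context_id:
        loss = forward_pass()                      # placeholder RPC-based forward
        dist_autograd.backward(context_id, [loss])
        dist_optim.step(context_id)                # gradients are read from the context;
                                                   # missing ones become None, as above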
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Type
|
||||
|
||||
from torch import optim
|
||||
|
||||
|
|
@ -46,7 +45,7 @@ def register_functional_optim(key, optim):
|
|||
functional_optim_map[key] = optim
|
||||
|
||||
|
||||
def as_functional_optim(optim_cls: Type, *args, **kwargs):
|
||||
def as_functional_optim(optim_cls: type, *args, **kwargs):
|
||||
try:
|
||||
functional_cls = functional_optim_map[optim_cls]
|
||||
except KeyError as e:
|
||||
|
|
@ -57,7 +56,7 @@ def as_functional_optim(optim_cls: Type, *args, **kwargs):
|
|||
return _create_functional_optim(functional_cls, *args, **kwargs)
|
||||
|
||||
|
||||
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
|
||||
def _create_functional_optim(functional_optim_cls: type, *args, **kwargs):
|
||||
return functional_optim_cls(
|
||||
[],
|
||||
*args,
|
||||
|
|
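as_functional_optim maps an eager optimizer class to its functional counterpart via functional_optim_map and constructs it with an empty parameter list. A small sketch (private helper; the import path is an assumption):

    import torch
    from torch.distributed.optim.utils import as_functional_optim

    # Returns e.g. a _FunctionalSGD built as _FunctionalSGD([], lr=0.01, ...),
    # into which callers later feed parameters and gradients explicitly.
    f_sgd = as_functional_optim(torch.optim.SGD, lr=0.01)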
|
|||
|
|
@ -11,7 +11,7 @@ import inspect
|
|||
import io
|
||||
import logging
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -156,7 +156,7 @@ class _DDPBucketAssignment:
|
|||
def __init__(
|
||||
self,
|
||||
bucket_index: int,
|
||||
parameters: List[torch.Tensor],
|
||||
parameters: list[torch.Tensor],
|
||||
offset: int,
|
||||
):
|
||||
self.bucket_index = bucket_index
|
||||
|
|
@ -239,20 +239,20 @@ class _OverlapInfo:
|
|||
self.shard_buckets: bool = False
|
||||
|
||||
# Modified per bucket reconstruction
|
||||
self.params_per_bucket: List[List[torch.Tensor]] = []
|
||||
self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
|
||||
self.offsets: Dict[int, int] = {}
|
||||
self.params_per_bucket: list[list[torch.Tensor]] = []
|
||||
self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)]
|
||||
self.offsets: dict[int, int] = {}
|
||||
# Group Ranks
|
||||
self.assigned_ranks_per_bucket: List[Set[int]] = []
|
||||
self.assigned_ranks_per_bucket: list[set[int]] = []
|
||||
self.num_bucket_assignments: int = 0
|
||||
self.total_size: Optional[int] = None
|
||||
|
||||
# Modified per iteration
|
||||
self.broadcast_handles: List[Any] = []
|
||||
self.bucket_indices_seen: List[int] = []
|
||||
self.broadcast_handles: list[Any] = []
|
||||
self.bucket_indices_seen: list[int] = []
|
||||
# Used by `hook_with_zero_step()`
|
||||
self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
|
||||
self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
|
||||
self.bucket_index_to_future: dict[int, torch.futures.Future] = {}
|
||||
self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {}
|
||||
|
||||
def wait_for_broadcasts(self) -> None:
|
||||
r"""
|
||||
|
|
@ -372,7 +372,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
def __init__(
|
||||
self,
|
||||
params,
|
||||
optimizer_class: Type[Optimizer],
|
||||
optimizer_class: type[Optimizer],
|
||||
process_group: Optional[Any] = None,
|
||||
parameters_as_bucket_view: bool = False,
|
||||
overlap_with_ddp: bool = False,
|
||||
|
|
@ -395,15 +395,15 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
# `self.param_groups`
|
||||
|
||||
# Internal data structures (`_cache` indicates lazily evaluated)
|
||||
self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
|
||||
self._param_to_index_cache: Dict[torch.Tensor, int] = {}
|
||||
self._partition_parameters_cache: List[List[Dict]] = []
|
||||
self._index_to_param_cache: List[torch.Tensor] = []
|
||||
self._device_to_params_per_rank_cache: Dict[
|
||||
torch.device, List[List[torch.Tensor]]
|
||||
self._param_to_rank_cache: dict[torch.Tensor, int] = {}
|
||||
self._param_to_index_cache: dict[torch.Tensor, int] = {}
|
||||
self._partition_parameters_cache: list[list[dict]] = []
|
||||
self._index_to_param_cache: list[torch.Tensor] = []
|
||||
self._device_to_params_per_rank_cache: dict[
|
||||
torch.device, list[list[torch.Tensor]]
|
||||
] = {}
|
||||
self._bucket_assignments_per_rank_cache: List[
|
||||
Dict[int, _DDPBucketAssignment]
|
||||
self._bucket_assignments_per_rank_cache: list[
|
||||
dict[int, _DDPBucketAssignment]
|
||||
] = []
|
||||
self._is_trainable_mask = self._get_is_trainable_mask()
|
||||
|
||||
|
|
@ -439,12 +439,12 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
# `self._buckets` is used if `parameters_as_bucket_view=True`, in
|
||||
# which case parameter data is flattened into contiguous bucket tensors
|
||||
self.parameters_as_bucket_view = parameters_as_bucket_view
|
||||
self._buckets: List[List[torch.Tensor]] = []
|
||||
self._buckets: list[list[torch.Tensor]] = []
|
||||
self._build_param_buckets()
|
||||
|
||||
# Optional consolidated optimizer state, only populated if this rank
|
||||
# is the target in `consolidate_state_dict()`
|
||||
self._all_state_dicts: List[Dict[str, Any]] = []
|
||||
self._all_state_dicts: list[dict[str, Any]] = []
|
||||
|
||||
self.initialized = True
|
||||
|
||||
|
|
@ -457,7 +457,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
self._device_to_params_per_rank_cache.clear()
|
||||
self._bucket_assignments_per_rank_cache.clear()
|
||||
|
||||
def add_param_group(self, param_group: Dict[str, Any]) -> None:
|
||||
def add_param_group(self, param_group: dict[str, Any]) -> None:
|
||||
r"""
|
||||
Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
|
||||
|
||||
|
|
@ -586,7 +586,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
|
||||
def _verify_params_per_rank(
|
||||
self,
|
||||
params_per_rank: List[List[torch.Tensor]],
|
||||
params_per_rank: list[list[torch.Tensor]],
|
||||
) -> None:
|
||||
r"""
|
||||
Verify ``params_per_rank`` for :meth:`_partition_parameters`.
|
||||
|
|
@ -619,7 +619,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
)
|
||||
|
||||
def _partition_param_group(
|
||||
self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
|
||||
self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]]
|
||||
) -> None:
|
||||
r"""
|
||||
Partition the parameter group ``param_group`` according to ``params_per_rank``.
|
||||
|
|
@ -641,8 +641,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
|
||||
def _partition_parameters(
|
||||
self,
|
||||
params_per_rank: Optional[List[List[torch.Tensor]]] = None,
|
||||
) -> List[List[Dict]]:
|
||||
params_per_rank: Optional[list[list[torch.Tensor]]] = None,
|
||||
) -> list[list[dict]]:
|
||||
r"""
|
||||
Partitions parameters across distributed data parallel ranks.
|
||||
|
||||
|
|
@ -674,7 +674,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
self._partition_parameters_cache = [[] for _ in range(self.world_size)]
|
||||
sizes = [0] * self.world_size
|
||||
for param_group in self.param_groups:
|
||||
param_group_params_per_rank: List[List] = [
|
||||
param_group_params_per_rank: list[list] = [
|
||||
[] for _ in range(self.world_size)
|
||||
]
|
||||
# Sort the parameters by size (largest first)
|
||||
|
|
@ -712,7 +712,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
return self._partition_parameters_cache
|
||||
|
||||
@property
|
||||
def _param_to_rank(self) -> Dict[torch.Tensor, int]:
|
||||
def _param_to_rank(self) -> dict[torch.Tensor, int]:
|
||||
r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
|
||||
if len(self._param_to_rank_cache) == 0:
|
||||
for rank, param_groups in enumerate(self._partition_parameters()):
|
||||
|
|
@ -722,7 +722,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
return self._param_to_rank_cache
|
||||
|
||||
@property
|
||||
def _param_to_index(self) -> Dict[torch.Tensor, int]:
|
||||
def _param_to_index(self) -> dict[torch.Tensor, int]:
|
||||
r"""
|
||||
:class:`dict` mapping parameters to their indices in the global optimizer state.
|
||||
|
||||
|
|
@ -737,7 +737,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
return self._param_to_index_cache
|
||||
|
||||
@property
|
||||
def _index_to_param(self) -> List[torch.Tensor]:
|
||||
def _index_to_param(self) -> list[torch.Tensor]:
|
||||
r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
|
||||
if len(self._index_to_param_cache) == 0:
|
||||
self._index_to_param_cache = list(
|
||||
|
|
@ -811,7 +811,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
@property
|
||||
def _device_to_params_per_rank(
|
||||
self,
|
||||
) -> Dict[torch.device, List[List[torch.Tensor]]]:
|
||||
) -> dict[torch.device, list[list[torch.Tensor]]]:
|
||||
r"""
|
||||
Return device parameters assigned per rank.
|
||||
|
||||
|
|
@ -854,8 +854,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
|
||||
def _get_min_index(
|
||||
self,
|
||||
values: List[int],
|
||||
disallowed_indices: Optional[Set[int]] = None,
|
||||
values: list[int],
|
||||
disallowed_indices: Optional[set[int]] = None,
|
||||
) -> int:
|
||||
r"""
|
||||
Return ``values.index(min(values))``, except it only uses one pass.
|
||||
|
|
@ -881,10 +881,10 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
def _assign_bucket_subset_to_rank(
|
||||
self,
|
||||
bucket_index: int,
|
||||
bucket_params: List[torch.Tensor],
|
||||
bucket_params: list[torch.Tensor],
|
||||
bucket_offset: int,
|
||||
assigned_rank: int,
|
||||
assigned_ranks_per_bucket: List[Set[int]],
|
||||
assigned_ranks_per_bucket: list[set[int]],
|
||||
) -> None:
|
||||
r"""
|
||||
Assign ``bucket_params`` to the rank with the least size assigned so far and collect relevant information.
|
||||
|
|
@ -919,7 +919,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
self._overlap_info.num_bucket_assignments += 1
|
||||
|
||||
@property
|
||||
def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
|
||||
def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]:
|
||||
r"""
|
||||
Return DDP bucket parameters assigned per rank.
|
||||
|
||||
|
|
@ -1015,7 +1015,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
|
||||
def _local_step(
|
||||
self,
|
||||
gradients: Optional[List[Optional[torch.Tensor]]] = None,
|
||||
gradients: Optional[list[Optional[torch.Tensor]]] = None,
|
||||
closure: Optional[Callable[[], float]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Optional[float]:
|
||||
|
|
@ -1148,7 +1148,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
r"""Return process group."""
|
||||
return self.process_group
|
||||
|
||||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
||||
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
||||
r"""
|
||||
Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.
|
||||
|
||||
|
|
@ -1186,7 +1186,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
self._sync_param_groups(state_dict["param_groups"], self.param_groups)
|
||||
self._sync_param_groups(self.param_groups, self.optim.param_groups)
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
r"""
|
||||
Return the last global optimizer state known to this rank.
|
||||
|
||||
|
|
@ -1252,8 +1252,8 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
|
||||
@staticmethod
|
||||
def _sync_param_groups(
|
||||
src_param_groups: List[Dict[Any, Any]],
|
||||
dst_param_groups: List[Dict[Any, Any]],
|
||||
src_param_groups: list[dict[Any, Any]],
|
||||
dst_param_groups: list[dict[Any, Any]],
|
||||
) -> None:
|
||||
r"""
|
||||
Sync the attributes from the source parameter groups to the destination parameter groups.
|
||||
|
|
@ -1381,7 +1381,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
def _verify_and_init_params(
|
||||
self,
|
||||
params: Any,
|
||||
) -> Union[List[torch.Tensor], List[dict]]:
|
||||
) -> Union[list[torch.Tensor], list[dict]]:
|
||||
r"""
|
||||
Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.
|
||||
|
||||
|
|
@ -1469,7 +1469,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|||
f"{other_typename}"
|
||||
)
|
||||
|
||||
def _get_is_trainable_mask(self) -> List[bool]:
|
||||
def _get_is_trainable_mask(self) -> list[bool]:
|
||||
r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not."""
|
||||
return list(map(_is_trainable, self._all_params))
|
||||
|
||||
|
|
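For orientation, all of the annotations above live inside ZeroRedundancyOptimizer, whose public usage is unchanged by this diff. A hedged sketch, assuming a process group is already initialized and model/inputs are defined elsewhere:

    import torch
    import torch.distributed as dist
    from torch.distributed.optim import ZeroRedundancyOptimizer

    opt = ZeroRedundancyOptimizer(
        model.parameters(),                # placeholder model
        optimizer_class=torch.optim.Adam,  # local optimizer wrapped on each rank
        lr=1e-3,
    )
    loss = model(inputs).sum()             # placeholder forward
    loss.backward()
    opt.step()

    opt.consolidate_state_dict(to=0)
    if dist.get_rank() == 0:
        full_state = opt.state_dict()      # the dict[str, Any] annotated above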
|
|||
|
|
@ -7,7 +7,7 @@ from collections import defaultdict
|
|||
from enum import Enum
|
||||
from inspect import Parameter, Signature, signature
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
|
@ -116,7 +116,7 @@ def _insert_stage_symbolic_backward(
|
|||
output_node: fx.Node,
|
||||
):
|
||||
# Collect metadata about tuple output values. TODO: move this to split_module or FX IR
|
||||
tuples: Dict[fx.Node, Tuple] = {}
|
||||
tuples: dict[fx.Node, tuple] = {}
|
||||
for node in reversed(g.nodes):
|
||||
if node.op == "call_function":
|
||||
# In the forward pass, only emit placeholder, module calls, and
|
||||
|
|
@ -155,7 +155,7 @@ def _insert_stage_symbolic_backward(
|
|||
# We will only emit backward operations for nodes that can contribute
|
||||
# to the specified loss value.
|
||||
live_nodes = {loss_node: None}
|
||||
val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
|
||||
val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None}
|
||||
|
||||
def assign_or_accumulate_grad(forward_node, grad_value):
|
||||
if forward_node in val_to_grad and forward_node.op != "placeholder":
|
||||
|
|
@ -349,7 +349,7 @@ class MultiUseParameterConfig(Enum):
|
|||
REPLICATE = 2
|
||||
|
||||
|
||||
MultiUseParamSpec = Union[MultiUseParameterConfig, Dict[str, MultiUseParameterConfig]]
|
||||
MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]]
|
||||
|
||||
|
||||
class DetachExecutor(fx.Interpreter):
|
||||
|
|
@ -432,7 +432,7 @@ class _LinearNodeList:
|
|||
def to_graph(self):
|
||||
graph = fx.Graph()
|
||||
|
||||
ref_str_to_node: Dict[str, fx.Node] = {}
|
||||
ref_str_to_node: dict[str, fx.Node] = {}
|
||||
|
||||
def ref_to_node(arg):
|
||||
if isinstance(arg, _NodeReference):
|
||||
|
|
@ -557,14 +557,14 @@ class Pipe(torch.nn.Module):
|
|||
|
||||
# Map parameter value to a dictionary that maps the user pipeline module
|
||||
# to the local qualname within that module
|
||||
params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}
|
||||
params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {}
|
||||
|
||||
for m_qualname, mod in self.split_gm.named_children():
|
||||
for p_qualname, param in mod.named_parameters():
|
||||
params_to_users.setdefault(param, {})
|
||||
params_to_users[param][m_qualname] = p_qualname
|
||||
|
||||
self.replicated_params: List[Dict[str, str]] = [
|
||||
self.replicated_params: list[dict[str, str]] = [
|
||||
use_mapping
|
||||
for _, use_mapping in params_to_users.items()
|
||||
if len(use_mapping) > 1
|
||||
|
|
@ -645,7 +645,7 @@ class Pipe(torch.nn.Module):
|
|||
@staticmethod
|
||||
def _number_and_count_forward_stages(gm: fx.GraphModule):
|
||||
num_stages = 0
|
||||
found_idxs: Dict[int, None] = {}
|
||||
found_idxs: dict[int, None] = {}
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_module" and node.target.startswith("submod_"):
|
||||
node.meta["stage_idx"] = int(node.target[len("submod_") :])
|
||||
|
|
@ -693,7 +693,7 @@ class Pipe(torch.nn.Module):
|
|||
# Deduplicate `get_attr` nodes that refer to the same parameter. Downstream code for moving
|
||||
# parameters relies on the invariant that parameter accesses happen once. This is not necessarily
|
||||
# the case (especially with custom tracers), so fix that up here.
|
||||
get_attr_nodes: Dict[str, fx.Node] = {}
|
||||
get_attr_nodes: dict[str, fx.Node] = {}
|
||||
for node in traced.graph.nodes: # type: ignore[union-attr]
|
||||
if node.op == "get_attr":
|
||||
get_attr_nodes.setdefault(node.target, node)
|
||||
|
|
@ -868,7 +868,7 @@ class Pipe(torch.nn.Module):
|
|||
|
||||
# [aliasing] store tensor id -> list of FQNs, built from state dict
|
||||
# Also assign non-persistent buffers
|
||||
id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
|
||||
id_to_fqns: dict[int, set[str]] = defaultdict(set)
|
||||
for fqn, tensor in mod.state_dict(keep_vars=True).items():
|
||||
id_to_fqns[id(tensor)].add(fqn)
|
||||
for fqn, tensor in mod.named_buffers():
|
||||
|
|
@ -878,7 +878,7 @@ class Pipe(torch.nn.Module):
|
|||
# need to move the `get_attr` nodes from the root of the graph to those
|
||||
# hierarchies.
|
||||
# [aliasing] use id -> fqn mapping to list out all valid FQNs
|
||||
inputs_to_state: Dict[str, List[str]] = {}
|
||||
inputs_to_state: dict[str, list[str]] = {}
|
||||
for attr in attr_nodes:
|
||||
_, tensor = _recursive_getattr_with_parent(mod, attr.target)
|
||||
fqns = list(id_to_fqns[id(tensor)])
|
||||
|
|
@ -890,7 +890,7 @@ class Pipe(torch.nn.Module):
|
|||
# [aliasing] for each submodule split, assign attributes on FQNs that may be used.
|
||||
# We determine this based on whether or not the FQN attribute parent exists.
|
||||
# i.e. if the last submodule exists, assign the attribute.
|
||||
added_attributes: Dict[str, List[str]] = defaultdict(list)
|
||||
added_attributes: dict[str, list[str]] = defaultdict(list)
|
||||
for fqn, tensor in mod.state_dict(keep_vars=True).items():
|
||||
for name, submod in split.named_children():
|
||||
if isinstance(submod, fx.GraphModule):
|
||||
|
|
@ -999,7 +999,7 @@ class Pipe(torch.nn.Module):
|
|||
def _trace_with_export(
|
||||
mod: torch.nn.Module,
|
||||
example_args: tuple[Any, ...],
|
||||
example_kwargs: Optional[Dict[str, Any]] = None,
|
||||
example_kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> ExportedProgram:
|
||||
logger.info("Tracing model ...")
|
||||
try:
|
||||
|
|
@ -1023,7 +1023,7 @@ class Pipe(torch.nn.Module):
|
|||
def from_tracing(
|
||||
mod: torch.nn.Module,
|
||||
example_args: tuple[Any, ...],
|
||||
example_kwargs: Optional[Dict[str, Any]] = None,
|
||||
example_kwargs: Optional[dict[str, Any]] = None,
|
||||
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
|
||||
):
|
||||
# If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
|
||||
|
|
@ -1173,7 +1173,7 @@ def _split_after_forward(self, *args, **kwargs):
|
|||
pipe_split()
|
||||
|
||||
|
||||
def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
|
||||
def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
|
||||
# TODO: make this implementation out-of-place?
|
||||
for qualname, split_type in spec.items():
|
||||
atoms = qualname.split(".")
|
||||
|
|
@ -1200,8 +1200,8 @@ def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
|
|||
def pipeline(
|
||||
module: torch.nn.Module,
|
||||
mb_args: tuple[Any, ...],
|
||||
mb_kwargs: Optional[Dict[str, Any]] = None,
|
||||
split_spec: Optional[Dict[str, SplitPoint]] = None,
|
||||
mb_kwargs: Optional[dict[str, Any]] = None,
|
||||
split_spec: Optional[dict[str, SplitPoint]] = None,
|
||||
split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
|
||||
) -> Pipe:
|
||||
"""
|
||||
|
|
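pipeline() above is the public entry point for building a Pipe; split_spec takes the dict[str, SplitPoint] now spelled with builtin generics. A hedged sketch with a placeholder model and example microbatch:

    import torch
    from torch.distributed.pipelining import pipeline, SplitPoint

    model = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(4)])
    x = torch.randn(2, 16)

    pipe = pipeline(
        model,
        mb_args=(x,),                            # one example microbatch
        split_spec={"2": SplitPoint.BEGINNING},  # split before submodule "2"
    )
    # pipe.num_stages == 2 for this split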
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import collections
|
||||
import logging
|
||||
from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Union
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.autograd.graph import GradientEdge, Node
|
||||
|
|
@ -37,8 +38,8 @@ def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]:
|
|||
|
||||
|
||||
def reverse_closure(
|
||||
roots: List[Node], target_nodes: Set[Node], reverse_edges_dict
|
||||
) -> tuple[Set[Node], Set[Node]]:
|
||||
roots: list[Node], target_nodes: set[Node], reverse_edges_dict
|
||||
) -> tuple[set[Node], set[Node]]:
|
||||
"""
|
||||
This function returns the reverse closure of the given roots,
|
||||
i.e. the set of nodes that can be reached from the roots by following the
|
||||
|
|
@ -46,9 +47,9 @@ def reverse_closure(
|
|||
include in the closure.
|
||||
"""
|
||||
# Recurse until we reach a target node
|
||||
closure: Set[Node] = set()
|
||||
closure: set[Node] = set()
|
||||
visited_target_nodes = set()
|
||||
q: Deque[Node] = collections.deque()
|
||||
q: collections.deque[Node] = collections.deque()
|
||||
for node in roots:
|
||||
if node is not None and node not in closure:
|
||||
closure.add(node)
|
||||
|
|
@ -67,10 +68,10 @@ def reverse_closure(
|
|||
return closure, visited_target_nodes
|
||||
|
||||
|
||||
def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]:
|
||||
q: Deque[Node] = collections.deque()
|
||||
root_seen: Set[Node] = set()
|
||||
reverse_edges_dict: Dict[Node, List[Node]] = collections.defaultdict(list)
|
||||
def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]:
|
||||
q: collections.deque[Node] = collections.deque()
|
||||
root_seen: set[Node] = set()
|
||||
reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list)
|
||||
for node in roots:
|
||||
if node is not None and node not in root_seen:
|
||||
q.append(node)
|
||||
|
|
@ -86,8 +87,8 @@ def construct_reverse_graph(roots: List[Node]) -> Dict[Node, List[Node]]:
|
|||
|
||||
|
||||
def get_param_groups(
|
||||
inputs: List[Node], params: List[Node], reverse_edges_dict
|
||||
) -> List[Dict[str, Any]]:
|
||||
inputs: list[Node], params: list[Node], reverse_edges_dict
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Given a list of inputs and a list of parameters, return a list of parameter
|
||||
groups, where each group contains the parameters and the intermediates that
|
||||
|
|
@ -103,12 +104,12 @@ def get_param_groups(
|
|||
# reverse graph that starts with inputs, and goes up to the dOutput or the loss,
|
||||
# but omits weights and any subgraphs connecting weights to this closure
|
||||
inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict)
|
||||
param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates
|
||||
param_groups: dict[Node, dict[str, set]] = dict() # keyed on intermediates
|
||||
for param in params:
|
||||
closure, intersected = reverse_closure(
|
||||
[param], inputs_closure, reverse_edges_dict
|
||||
)
|
||||
param_group: Dict[str, Set] = {
|
||||
param_group: dict[str, set] = {
|
||||
"params": {param},
|
||||
"intermediates": intersected,
|
||||
}
|
||||
|
|
@ -124,8 +125,8 @@ def get_param_groups(
|
|||
param_groups[input_node] = param_group
|
||||
|
||||
# Sanity check: union of all param_groups params should be equal to all params
|
||||
union_params: Set[Node] = set()
|
||||
seen_ids: Set[int] = set()
|
||||
union_params: set[Node] = set()
|
||||
seen_ids: set[int] = set()
|
||||
unique_param_groups = []
|
||||
for param_group in param_groups.values():
|
||||
if id(param_group) not in seen_ids:
|
||||
|
|
@ -140,11 +141,11 @@ def get_param_groups(
|
|||
|
||||
|
||||
def stage_backward_input(
|
||||
stage_outputs_or_loss: List[torch.Tensor],
|
||||
output_grads: Optional[List[torch.Tensor]],
|
||||
input_values: List[torch.Tensor],
|
||||
stage_outputs_or_loss: list[torch.Tensor],
|
||||
output_grads: Optional[list[torch.Tensor]],
|
||||
input_values: list[torch.Tensor],
|
||||
weights: Iterator[Parameter],
|
||||
) -> tuple[tuple[Optional[torch.Tensor], ...], List[Dict[str, Any]]]:
|
||||
) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]:
|
||||
"""
|
||||
Compute the gradients for only the stage inputs with
|
||||
respect to the stage outputs (if non-last stage) or loss (if last stage)
|
||||
|
|
@ -155,13 +156,13 @@ def stage_backward_input(
|
|||
Detaching the stage_outputs_or_loss at the end of this function is important as
|
||||
it frees memory that the autograd graph anticipates using later (but does not actually need).
|
||||
"""
|
||||
stage_output_grad_fns: List[Node] = list(
|
||||
stage_output_grad_fns: list[Node] = list(
|
||||
filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss))
|
||||
)
|
||||
stage_input_grad_fns: List[Node] = list(
|
||||
stage_input_grad_fns: list[Node] = list(
|
||||
filter(None, map(_get_grad_fn_or_grad_acc, input_values))
|
||||
)
|
||||
weight_grad_fns: List[Node] = list(
|
||||
weight_grad_fns: list[Node] = list(
|
||||
filter(None, map(_get_grad_fn_or_grad_acc, weights))
|
||||
)
|
||||
|
||||
|
|
@ -222,11 +223,11 @@ def stage_backward_input(
|
|||
|
||||
|
||||
def stage_backward_weight(
|
||||
weights: Iterator[Parameter], param_groups: List[Dict[str, Any]], retain_graph=False
|
||||
weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False
|
||||
) -> tuple[Optional[torch.Tensor], ...]:
|
||||
# map weights to param_group_weights
|
||||
grad_acc_to_weight = {}
|
||||
weight_grads: List[Optional[torch.Tensor]] = []
|
||||
weight_grads: list[Optional[torch.Tensor]] = []
|
||||
for index, weight in enumerate(weights):
|
||||
grad_acc = _get_grad_fn_or_grad_acc(weight)
|
||||
grad_acc_to_weight[grad_acc] = weight, index
|
||||
|
|
@ -273,7 +274,7 @@ def stage_backward(
|
|||
stage_output,
|
||||
output_grads,
|
||||
input_values,
|
||||
outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used
|
||||
outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used
|
||||
) -> tuple[Optional[torch.Tensor], ...]:
|
||||
"""
|
||||
This is a helper function to:
|
||||
|
|
@ -293,8 +294,8 @@ def stage_backward(
|
|||
try:
|
||||
# stage_output may be a composite datatype like dict. Extract all individual
|
||||
# tensor values here
|
||||
stage_output_tensors: List[torch.Tensor] = []
|
||||
output_grad_tensors: List[Optional[torch.Tensor]] = []
|
||||
stage_output_tensors: list[torch.Tensor] = []
|
||||
output_grad_tensors: list[Optional[torch.Tensor]] = []
|
||||
|
||||
def extract_tensors_with_grads(
|
||||
output_val,
|
||||
|
|
@ -353,7 +354,7 @@ def stage_backward(
|
|||
)
|
||||
|
||||
# Extract gradients wrt the input values
|
||||
grad_inputs: List[Optional[torch.Tensor]] = []
|
||||
grad_inputs: list[Optional[torch.Tensor]] = []
|
||||
for val in input_values:
|
||||
if isinstance(val, torch.Tensor):
|
||||
grad_inputs.append(val.grad)
|
||||
|
|
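The helpers above walk the autograd graph from grad_fn or gradient-accumulator nodes. Leaf tensors do not expose their accumulator directly; a common idiom (and, as an assumption, what _get_grad_fn_or_grad_acc relies on) reaches it through a trivial view:

    import torch

    w = torch.randn(3, requires_grad=True)    # leaf tensor: w.grad_fn is None
    acc = w.view_as(w).grad_fn.next_functions[0][0]
    print(type(acc).__name__)                 # AccumulateGrad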
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set
|
||||
|
||||
import torch
|
||||
from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
|
||||
|
|
@ -9,10 +8,10 @@ from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry
|
|||
def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule:
|
||||
# Create an empty GraphModule to hold the outlined modules
|
||||
new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
|
||||
seen_nodes: Dict[str, torch.fx.Node] = {}
|
||||
seen_modules: Dict[int, List[_SubmoduleEntry]] = defaultdict(list)
|
||||
seen_attrs: Dict[str, Set[str]] = defaultdict(set)
|
||||
created_modules: Dict[str, torch.nn.Module] = {}
|
||||
seen_nodes: dict[str, torch.fx.Node] = {}
|
||||
seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list)
|
||||
seen_attrs: dict[str, set[str]] = defaultdict(set)
|
||||
created_modules: dict[str, torch.nn.Module] = {}
|
||||
_ModuleFrame(
|
||||
orig_graph,
|
||||
tuple(orig_graph.nodes),
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
|
@ -75,8 +75,8 @@ def validate_tensor_metadata(desc, expected, given):
|
|||
|
||||
def validate_tensors_metadata(
|
||||
desc,
|
||||
expected_tensors: Union[List[torch.Tensor], tuple[torch.Tensor, ...]],
|
||||
actual_tensors: Union[List[torch.Tensor], tuple[torch.Tensor, ...]],
|
||||
expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
|
||||
actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
|
||||
):
|
||||
if len(expected_tensors) != len(actual_tensors):
|
||||
raise PipeliningShapeError(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.fx.node import map_aggregate
|
||||
|
|
@ -92,7 +92,7 @@ class TensorChunkSpec:
|
|||
|
||||
@staticmethod
|
||||
def from_dict(
|
||||
chunk_dims: Dict[str, int],
|
||||
chunk_dims: dict[str, int],
|
||||
):
|
||||
"""
|
||||
A helper for creating a dictionary of `TensorChunkSpec` from a
|
||||
|
|
@ -243,11 +243,11 @@ def _shard_dict_of_args(
|
|||
|
||||
def split_args_kwargs_into_chunks(
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]],
|
||||
kwargs: Optional[dict[str, Any]],
|
||||
chunks: int,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
) -> tuple[List[Tuple], List[Dict]]:
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
) -> tuple[list[tuple], list[dict]]:
|
||||
"""
|
||||
Given a sequence of args and kwargs, split them into a number of chunks
|
||||
according to their respective chunking specs.
|
||||
|
|
@ -347,7 +347,7 @@ def split_args_kwargs_into_chunks(
|
|||
|
||||
|
||||
def merge_chunks(
|
||||
chunks: List[Any],
|
||||
chunks: list[Any],
|
||||
chunk_spec,
|
||||
):
|
||||
"""
|
||||
|
|
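split_args_kwargs_into_chunks is the microbatching entry point whose return type is now tuple[list[tuple], list[dict]]. A small sketch, assuming the module's default chunk spec splits along dimension 0:

    import torch
    from torch.distributed.pipelining.microbatch import split_args_kwargs_into_chunks

    x = torch.randn(8, 4)
    args_split, kwargs_split = split_args_kwargs_into_chunks((x,), None, chunks=4)

    assert len(args_split) == 4 and len(kwargs_split) == 4
    assert args_split[0][0].shape == (2, 4)   # each microbatch is a dim-0 slice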
|
|||
|
|
@ -9,17 +9,7 @@ import re
|
|||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -163,7 +153,7 @@ class _Action(NamedTuple):
|
|||
|
||||
|
||||
def _format_pipeline_order(
|
||||
pipeline_order: Dict[int, List[Optional[_Action]]],
|
||||
pipeline_order: dict[int, list[Optional[_Action]]],
|
||||
error_step_number: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -230,8 +220,8 @@ class _PipelineSchedule(ABC):
|
|||
n_microbatches: int,
|
||||
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
# From arguments
|
||||
|
|
@ -256,7 +246,7 @@ class _PipelineSchedule(ABC):
|
|||
self._has_backward = self._loss_fn is not None
|
||||
|
||||
# Holds the losses for each microbatch.
|
||||
self._internal_losses: List[torch.Tensor] = []
|
||||
self._internal_losses: list[torch.Tensor] = []
|
||||
logger.info("Using %s", self.__class__.__name__)
|
||||
|
||||
def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
|
||||
|
|
@ -302,10 +292,10 @@ class _PipelineSchedule(ABC):
|
|||
@abstractmethod
|
||||
def _step_microbatches(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule with list of microbatches.
|
||||
|
|
@ -318,7 +308,7 @@ class _PipelineSchedule(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
|
||||
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule with *whole-batch* input.
|
||||
Will chunk the input into microbatches automatically, and go through the
|
||||
|
|
@ -333,10 +323,10 @@ class _PipelineSchedule(ABC):
|
|||
|
||||
def _check_inputs(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Pre-process/check inputs
|
||||
|
|
@ -375,7 +365,7 @@ class _PipelineSchedule(ABC):
|
|||
def _split_inputs(
|
||||
self,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Splits a full-batch input into chunks (i.e. microbatches) and returns
|
||||
|
|
@ -395,7 +385,7 @@ class _PipelineSchedule(ABC):
|
|||
# Return a list of empty tuples/dicts with matching length as chunks
|
||||
return [()] * self._n_microbatches, [{}] * self._n_microbatches
|
||||
|
||||
def _merge_outputs(self, output_chunks: List[Any]) -> Any:
|
||||
def _merge_outputs(self, output_chunks: list[Any]) -> Any:
|
||||
"""
|
||||
Merge output chunks back to a batch state.
|
||||
If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
|
||||
|
|
@ -406,7 +396,7 @@ class _PipelineSchedule(ABC):
|
|||
)
|
||||
|
||||
|
||||
def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
|
||||
def _batch_p2p(p2p_ops: list[dist.P2POp], desc: Optional[str] = None):
|
||||
"""
|
||||
Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
|
||||
"""
|
||||
|
|
@ -418,8 +408,8 @@ def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
|
|||
|
||||
|
||||
def _sorted_batch_p2p(
|
||||
p2p_ops: List[dist.P2POp], desc: Optional[str] = None
|
||||
) -> Dict[int, dist.Work]:
|
||||
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
|
||||
) -> dict[int, dist.Work]:
|
||||
"""
|
||||
Sorts the list of P2P ops by the peer rank, and then calls
|
||||
batch_isend_irecv. Return a dictionary of works by peer rank. This function
|
||||
|
|
@ -428,8 +418,8 @@ def _sorted_batch_p2p(
|
|||
# Arrange p2p_ops by peer rank:
|
||||
# int is the peer rank;
|
||||
# List is the list of ops towards the peer
|
||||
ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
|
||||
work_by_peer: Dict[int, dist.Work] = {}
|
||||
ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
|
||||
work_by_peer: dict[int, dist.Work] = {}
|
||||
if len(p2p_ops) == 0:
|
||||
return work_by_peer
|
||||
|
||||
|
|
@ -461,8 +451,8 @@ class PipelineScheduleSingle(_PipelineSchedule):
|
|||
n_microbatches: int,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
# Init parent
|
||||
|
|
@ -493,7 +483,7 @@ or equal to the number of stages ({self._num_stages})."
|
|||
self._stage._prepare_backward_infra(self._n_microbatches)
|
||||
self._stage_initialized = True
|
||||
|
||||
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
|
||||
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule with *whole-batch* input.
|
||||
Will chunk the input into microbatches automatically, and go through the
|
||||
|
|
@ -535,10 +525,10 @@ class _ScheduleForwardOnly(PipelineScheduleSingle):
|
|||
|
||||
def _step_microbatches(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule
|
||||
|
|
@ -553,7 +543,7 @@ class _ScheduleForwardOnly(PipelineScheduleSingle):
|
|||
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
|
||||
|
||||
# Delay send waits
|
||||
fwd_sends_to_wait: List[dist.Work] = []
|
||||
fwd_sends_to_wait: list[dist.Work] = []
|
||||
|
||||
# Run microbatches
|
||||
for i in range(self._n_microbatches):
|
||||
|
|
@ -586,10 +576,10 @@ class ScheduleGPipe(PipelineScheduleSingle):
|
|||
|
||||
def _step_microbatches(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule with list of microbatches.
|
||||
|
|
@ -604,7 +594,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
|
|||
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
|
||||
|
||||
# Delay send waits
|
||||
fwd_sends_to_wait: List[dist.Work] = []
|
||||
fwd_sends_to_wait: list[dist.Work] = []
|
||||
|
||||
# Run microbatches
|
||||
for i in range(self._n_microbatches):
|
||||
|
|
@ -636,7 +626,7 @@ class ScheduleGPipe(PipelineScheduleSingle):
|
|||
|
||||
# Run backward
|
||||
# Delay send waits
|
||||
bwd_sends_to_wait: List[dist.Work] = []
|
||||
bwd_sends_to_wait: list[dist.Work] = []
|
||||
for i in range(self._n_microbatches):
|
||||
with record_function(f"Backward {i}"):
|
||||
ops = self._stage.get_bwd_recv_ops(i)
|
||||
|
|
@ -677,10 +667,10 @@ class Schedule1F1B(PipelineScheduleSingle):
|
|||
|
||||
def _step_microbatches(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule with list of microbatches.
|
||||
|
|
@ -820,9 +810,9 @@ class Schedule1F1B(PipelineScheduleSingle):
|
|||
|
||||
|
||||
def _add_unshard_reshard(
|
||||
compute_actions: List[Optional[_Action]],
|
||||
compute_actions: list[Optional[_Action]],
|
||||
max_active_stages: int = 3,
|
||||
) -> List[_Action]:
|
||||
) -> list[_Action]:
|
||||
"""Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.
|
||||
|
||||
UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
|
||||
|
|
@ -836,11 +826,11 @@ def _add_unshard_reshard(
|
|||
"""
|
||||
|
||||
def next_stage_indices(
|
||||
count: int, next_actions: List[Optional[_Action]]
|
||||
) -> List[int]:
|
||||
count: int, next_actions: list[Optional[_Action]]
|
||||
) -> list[int]:
|
||||
"""Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
|
||||
seen: Set[int] = set()
|
||||
ret: List[int] = []
|
||||
seen: set[int] = set()
|
||||
ret: list[int] = []
|
||||
|
||||
for a in next_actions:
|
||||
if a is not None and a.stage_index not in seen:
|
||||
|
|
@ -850,8 +840,8 @@ def _add_unshard_reshard(
|
|||
break
|
||||
return ret
|
||||
|
||||
active_stages: Set[int] = set()
|
||||
fsdp_aware_actions: List[_Action] = []
|
||||
active_stages: set[int] = set()
|
||||
fsdp_aware_actions: list[_Action] = []
|
||||
|
||||
def _unshard(stage_index: int):
|
||||
active_stages.add(stage_index)
|
||||
|
|
@ -890,8 +880,8 @@ def _add_unshard_reshard(
|
|||
|
||||
|
||||
def _merge_bw(
|
||||
compute_actions: List[Optional[_Action]],
|
||||
) -> List[_Action]:
|
||||
compute_actions: list[Optional[_Action]],
|
||||
) -> list[_Action]:
|
||||
"""Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
|
||||
(note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
|
||||
|
||||
|
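A toy illustration of the merge rule described above, using (type, stage, microbatch) tuples in place of _Action objects; this is a hypothetical sketch of the intent, not the diff's implementation:

def merge_bw_sketch(actions):
    merged, i = [], 0
    while i < len(actions):
        a = actions[i]
        nxt = actions[i + 1] if i + 1 < len(actions) else None
        # Adjacent I (backward-input) and W (backward-weight) for the same
        # stage and microbatch fuse into a single full backward B.
        if a and nxt and a[0] == "I" and nxt[0] == "W" and a[1:] == nxt[1:]:
            merged.append(("B",) + a[1:])
            i += 2
        else:
            merged.append(a)
            i += 1
    return merged


assert merge_bw_sketch([("F", 0, 0), ("I", 0, 0), ("W", 0, 0)]) == [("F", 0, 0), ("B", 0, 0)]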
|
@ -925,12 +915,12 @@ def _merge_bw(
|
|||
|
||||
|
||||
def _add_send_recv(
|
||||
compute_actions: Dict[int, List[_Action]],
|
||||
compute_actions: dict[int, list[_Action]],
|
||||
stage_to_rank: Callable[[int], int],
|
||||
num_stages: int,
|
||||
) -> Dict[int, List[_Action]]:
|
||||
comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}
|
||||
prev_actions: Dict[int, Set[_Action]] = {rank: set() for rank in compute_actions}
|
||||
) -> dict[int, list[_Action]]:
|
||||
comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
|
||||
prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}
|
||||
|
||||
def _has_comms(action: _Action) -> bool:
|
||||
if action.computation_type == F:
|
||||
|
|
@ -954,7 +944,7 @@ def _add_send_recv(
|
|||
return send, recv
|
||||
|
||||
def _ready_to_schedule(
|
||||
action: Optional[_Action], prev_actions: Set[_Action]
|
||||
action: Optional[_Action], prev_actions: set[_Action]
|
||||
) -> bool:
|
||||
"""We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
|
||||
This helps ensure a sane (non-hanging) ordering of sends and recvs.
|
||||
|
|
@ -1030,7 +1020,7 @@ def _add_send_recv(
|
|||
|
||||
|
||||
def _validate_schedule(
|
||||
actions: Dict[int, List[Optional[_Action]]],
|
||||
actions: dict[int, list[Optional[_Action]]],
|
||||
pp_group_size: int,
|
||||
num_stages: int,
|
||||
num_microbatches: int,
|
||||
|
|
@ -1043,7 +1033,7 @@ def _validate_schedule(
|
|||
|
||||
# We will count all the actions per stage and ensure they happen in a valid order
|
||||
# (e.g. F before (B, I) before W for a given microbatch)
|
||||
stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
|
||||
stage_actions: dict[int, dict[_ComputationType, set]] = {
|
||||
stage_id: {
|
||||
F: set(),
|
||||
B: set(),
|
||||
|
|
@ -1108,13 +1098,13 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[_PipelineStageBase],
|
||||
stages: list[_PipelineStageBase],
|
||||
n_microbatches: int,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
stage_index_to_group_rank: Optional[Dict[int, int]] = None,
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
stage_index_to_group_rank: Optional[dict[int, int]] = None,
|
||||
use_full_backward: Optional[bool] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
|
|
@ -1148,7 +1138,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
self._should_compute_loss = lambda stage: stage.is_last and has_loss
|
||||
|
||||
# This will be set during init of derived schedules
|
||||
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
||||
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
|
||||
|
||||
if use_full_backward is not None:
|
||||
logger.warning(
|
||||
|
|
@ -1199,7 +1189,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
self._n_microbatches,
|
||||
)
|
||||
|
||||
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
|
||||
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
|
||||
"""
|
||||
Run one iteration of the pipeline schedule with *whole-batch* input.
|
||||
Will chunk the input into microbatches automatically, and go through the
|
||||
|
|
@ -1235,10 +1225,10 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
|
||||
def _step_microbatches(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Operate on the microbatches for looped schedules (multiple stages on each rank).
|
||||
|
|
@ -1253,14 +1243,14 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
|
||||
# Based on the plan in Step 1 created in __init__:
|
||||
# 2. Perform communication based on the pipeline_order
|
||||
stage_index_to_stage: Dict[int, _PipelineStageBase] = {
|
||||
stage_index_to_stage: dict[int, _PipelineStageBase] = {
|
||||
stage.stage_index: stage for stage in self._stages
|
||||
}
|
||||
|
||||
# determine prev_rank and next_rank based on which ranks are next to
|
||||
# the stages in the pipeline_order
|
||||
all_prev_ranks: Set[int] = set()
|
||||
all_next_ranks: Set[int] = set()
|
||||
all_prev_ranks: set[int] = set()
|
||||
all_next_ranks: set[int] = set()
|
||||
for stage_index in stage_index_to_stage.keys():
|
||||
# TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
|
||||
if stage_index > 0:
|
||||
|
|
@ -1271,7 +1261,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
backward_counter: Counter[int] = Counter()
|
||||
for time_step, action in enumerate(self.pipeline_order[self.rank]):
|
||||
try:
|
||||
ops: List[dist.P2POp] = []
|
||||
ops: list[dist.P2POp] = []
|
||||
if action is not None:
|
||||
computation_type = action.computation_type
|
||||
mb_index = action.microbatch_index
|
||||
|
|
@ -1432,7 +1422,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||
|
||||
def _load_actions(
|
||||
self,
|
||||
actions: Dict[int, List[Optional[_Action]]],
|
||||
actions: dict[int, list[Optional[_Action]]],
|
||||
format: str = "compute_only",
|
||||
):
|
||||
"""
|
||||
|
|
@ -1442,7 +1432,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||
assert (
|
||||
self.stage_index_to_group_rank is not None
|
||||
), "stage_index_to_group_rank is required for PipelineScheduleRuntime"
|
||||
self.pipeline_order_with_comms: Dict[int, List[_Action]] = {}
|
||||
self.pipeline_order_with_comms: dict[int, list[_Action]] = {}
|
||||
if format == "compute_comms":
|
||||
for rank in actions:
|
||||
self.pipeline_order_with_comms[rank] = []
|
||||
|
|
@ -1507,10 +1497,10 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||
|
||||
def _step_microbatches(
|
||||
self,
|
||||
arg_mbs: Optional[List] = None,
|
||||
kwarg_mbs: Optional[List] = None,
|
||||
target_mbs: Optional[List] = None,
|
||||
losses: Optional[List] = None,
|
||||
arg_mbs: Optional[list] = None,
|
||||
kwarg_mbs: Optional[list] = None,
|
||||
target_mbs: Optional[list] = None,
|
||||
losses: Optional[list] = None,
|
||||
):
|
||||
"""
|
||||
Operate on the microbatches for looped schedules (multiple stages on each rank).
|
||||
|
|
@ -1524,7 +1514,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||
|
||||
# Based on the plan in Step 1 created in __init__:
|
||||
# 2. Perform communication based on the pipeline_order
|
||||
stage_index_to_stage: Dict[int, _PipelineStageBase] = {
|
||||
stage_index_to_stage: dict[int, _PipelineStageBase] = {
|
||||
stage.stage_index: stage for stage in self._stages
|
||||
}
|
||||
|
||||
|
|
@ -1533,14 +1523,14 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||
), "Must call _load_actions() before calling _step_microbatches()"
|
||||
|
||||
# recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
|
||||
bwd_recv_ops: Dict[tuple[int, int], Work] = {}
|
||||
fwd_recv_ops: Dict[tuple[int, int], Work] = {}
|
||||
bwd_recv_ops: dict[tuple[int, int], Work] = {}
|
||||
fwd_recv_ops: dict[tuple[int, int], Work] = {}
|
||||
|
||||
# send ops should be waited on before step() exits, mainly for hygiene
|
||||
send_ops: List[Work] = []
|
||||
send_ops: list[Work] = []
|
||||
|
||||
# we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
|
||||
unshard_ops: Dict[int, UnshardHandle] = {}
|
||||
unshard_ops: dict[int, UnshardHandle] = {}
|
||||
unsharded_stages = set()
|
||||
|
||||
def _assert_unsharded(stage_idx: int):
|
||||
|
|
@ -1751,10 +1741,10 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[_PipelineStageBase],
|
||||
stages: list[_PipelineStageBase],
|
||||
n_microbatches: int,
|
||||
loss_fn: Optional[Union[Callable, _Loss]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
|
|
@ -1768,7 +1758,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
|
|||
# 1. Create the pipeline_order (all ranks do this calculation)
|
||||
# This will be used to keep track of the current state of the entire pipeline
|
||||
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
||||
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
||||
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
|
||||
# ========================================================================
|
||||
for rank in range(self.pp_group_size):
|
||||
rank_ops = self._calculate_single_rank_operations(rank)
|
||||
|
|
@ -1782,7 +1772,7 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
|
|||
|
||||
# Store the list of operations used for that rank
|
||||
# Pre-padding, rank starts with no-ops based on the warmup.
|
||||
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
|
||||
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
|
||||
|
||||
for stage_index in stage_indices:
|
||||
rank_ops.extend(
|
||||
|
|
@ -1816,13 +1806,13 @@ def _get_1f1b_rank_ops(
|
|||
enable_zero_bubble=False,
|
||||
):
|
||||
# All stages start with handling microbatch 0
|
||||
fwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
bwd_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||
fwd_stage_mb_index: dict[int, int] = defaultdict(int)
|
||||
bwd_stage_mb_index: dict[int, int] = defaultdict(int)
|
||||
weight_stage_mb_index: dict[int, int] = defaultdict(int)
|
||||
|
||||
# Store the list of operations used for that rank
|
||||
# Pre-padding, rank starts with no-ops based on the warmup.
|
||||
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
|
||||
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
|
||||
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
|
||||
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
|
||||
# Formula:
|
||||
|
|
@ -1960,12 +1950,12 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[_PipelineStageBase],
|
||||
stages: list[_PipelineStageBase],
|
||||
n_microbatches: int,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
self.pp_group_size = stages[0].group_size
|
||||
|
|
@ -1991,12 +1981,12 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti):
|
|||
# 1. Create the pipeline_order (all ranks do this calculation)
|
||||
# This will be used to keep track of the current state of the entire pipeline
|
||||
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
||||
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
||||
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
|
||||
for rank in range(self.pp_group_size):
|
||||
rank_ops = self._calculate_single_rank_operations(rank)
|
||||
self.pipeline_order[rank] = rank_ops
|
||||
|
||||
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
|
||||
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
|
||||
def get_rank_warmup_ops(rank):
|
||||
# Warms up operations for last stage
|
||||
warmups_ops_last_stage = (
|
||||
|
|
@ -2069,12 +2059,12 @@ class ScheduleInterleavedZeroBubble(PipelineScheduleMulti):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[_PipelineStageBase],
|
||||
stages: list[_PipelineStageBase],
|
||||
n_microbatches: int,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
# TODO: we don't support Zero Bubble with torch.compile so we
|
||||
|
|
@ -2109,7 +2099,7 @@ stage modules that have used torch.compile"
|
|||
# 1. Create the pipeline_order (all ranks do this calculation)
|
||||
# This will be used to keep track of the current state of the entire pipeline
|
||||
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
||||
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
||||
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
|
||||
for rank in range(self.pp_group_size):
|
||||
rank_ops = self._calculate_single_rank_operations(rank)
|
||||
self.pipeline_order[rank] = rank_ops
|
||||
|
|
@ -2121,7 +2111,7 @@ stage modules that have used torch.compile"
|
|||
self.n_local_stages * self.pp_group_size,
|
||||
)
|
||||
|
||||
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
|
||||
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
|
||||
def get_rank_warmup_ops(rank):
|
||||
# Warms up operations for last stage
|
||||
warmups_ops_last_stage = (
|
||||
|
|
@ -2198,10 +2188,10 @@ stage modules that have used torch.compile"
|
|||
return (stage + 1, op, microbatch) not in seen_ops
|
||||
return False
|
||||
|
||||
seen_ops: Set[tuple[int, _ComputationType, int]] = set()
|
||||
result: Dict[int, List[Optional[_Action]]] = {}
|
||||
next_pointer: Dict[int, int] = {}
|
||||
bubbles_added: Dict[int, int] = {}
|
||||
seen_ops: set[tuple[int, _ComputationType, int]] = set()
|
||||
result: dict[int, list[Optional[_Action]]] = {}
|
||||
next_pointer: dict[int, int] = {}
|
||||
bubbles_added: dict[int, int] = {}
|
||||
total_bubbles_added = 0
|
||||
|
||||
for rank in range(self.pp_group_size):
|
||||
|
|
@ -2212,7 +2202,7 @@ stage modules that have used torch.compile"
|
|||
while True:
|
||||
should_stop = True
|
||||
|
||||
temp_seen_ops: Set[tuple[int, _ComputationType, int]] = set()
|
||||
temp_seen_ops: set[tuple[int, _ComputationType, int]] = set()
|
||||
|
||||
for rank in range(self.pp_group_size):
|
||||
timestamp = next_pointer[rank]
|
||||
|
|
@ -2270,13 +2260,13 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
stages: List[_PipelineStageBase],
|
||||
stages: list[_PipelineStageBase],
|
||||
n_microbatches: int,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
|
||||
kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[Dict[str, Any], tuple[Any]]] = None,
|
||||
stage_index_to_group_rank: Optional[Dict[int, int]] = None,
|
||||
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
|
||||
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
|
||||
stage_index_to_group_rank: Optional[dict[int, int]] = None,
|
||||
scale_grads: bool = True,
|
||||
):
|
||||
self.pp_group_size = stages[0].group_size
|
||||
|
|
@ -2303,16 +2293,16 @@ class ScheduleZBVZeroBubble(PipelineScheduleMulti):
|
|||
# 1. Create the pipeline_order (all ranks do this calculation)
|
||||
# This will be used to keep track of the current state of the entire pipeline
|
||||
# pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
|
||||
self.pipeline_order: Dict[int, List[Optional[_Action]]] = {}
|
||||
self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
|
||||
for rank in range(self.pp_group_size):
|
||||
rank_ops = self._calculate_single_rank_operations(rank)
|
||||
self.pipeline_order[rank] = rank_ops
|
||||
|
||||
def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]:
|
||||
def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
|
||||
# max(2 * self.pp_group_size - 1, ...) ensures the number of microbatches is at least
|
||||
# as large as the number of microbatches needed to fully utilize the pipeline
|
||||
n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
|
||||
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
|
||||
rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
|
||||
|
||||
# Forward and backward action counts for stage chunk 0 and chunk 1
|
||||
f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
|
||||
|
|
@ -2469,11 +2459,11 @@ def _simulate_comms_compute(
|
|||
rank: [a for a in pipeline_order[rank] if a is not None]
|
||||
for rank in sorted(pipeline_order)
|
||||
}
|
||||
_schedule: Dict[int, List[_Action | None]] = {
|
||||
_schedule: dict[int, list[_Action | None]] = {
|
||||
rank: [] for rank in sorted(pipeline_order)
|
||||
}
|
||||
|
||||
_prev_ops_rank: Dict[int, Set[_Action]] = {rank: set() for rank in _schedule}
|
||||
_prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}
|
||||
|
||||
def add_to_schedule(rank: int, action: Optional[_Action]):
|
||||
_schedule[rank].append(action)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import logging
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -160,10 +160,10 @@ class _PipelineStageBase(ABC):
|
|||
self.dw_builder = dw_builder
|
||||
|
||||
# backward state
|
||||
self.backward_state: Dict[int, tuple[Any, ...]] = {}
|
||||
self.backward_state: dict[int, tuple[Any, ...]] = {}
|
||||
|
||||
# store dw_runner per microbatch_id
|
||||
self.dw_runner: Dict[int, Callable[..., None]] = {}
|
||||
self.dw_runner: dict[int, Callable[..., None]] = {}
|
||||
|
||||
# `group_rank` is rank in process group `group`.
|
||||
self.group_rank = dist.get_rank(self.group)
|
||||
|
|
@ -176,11 +176,11 @@ class _PipelineStageBase(ABC):
|
|||
# Run time states
|
||||
self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None
|
||||
# map microbatch ID to list of forward tensor args
|
||||
self.fwd_cache: Dict[int, tuple[Any, List[torch.Tensor]]] = {}
|
||||
self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {}
|
||||
# map microbatch ID to list of backward grad tensor args
|
||||
self.bwd_cache: Dict[int, tuple[Optional[torch.Tensor], ...]] = {}
|
||||
self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {}
|
||||
# Caching chunk outputs for final output merge or reduction
|
||||
self.output_chunks: List[Any] = []
|
||||
self.output_chunks: list[Any] = []
|
||||
|
||||
# Initialize has_backward to false; this will be set to true if loss
|
||||
# function is passed to pipeline schedule
|
||||
|
|
@ -189,16 +189,16 @@ class _PipelineStageBase(ABC):
|
|||
self.log_prefix = f"[Stage {self.stage_index}]"
|
||||
|
||||
# Forward infra
|
||||
self.args_recv_info: Dict[int, tuple[InputInfo, ...]] = {}
|
||||
self.act_send_info: Dict[int, List] = {}
|
||||
self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {}
|
||||
self.act_send_info: dict[int, list] = {}
|
||||
|
||||
# Backward infra will be created lazily
|
||||
self.grad_recv_info: Dict = {}
|
||||
self.grad_send_info: Optional[List] = None
|
||||
self.grad_recv_info: dict = {}
|
||||
self.grad_send_info: Optional[list] = None
|
||||
|
||||
# To be populated later by the Schedule
|
||||
self.chunks: Optional[int] = None
|
||||
self.stage_index_to_group_rank: Dict[int, int] = {
|
||||
self.stage_index_to_group_rank: dict[int, int] = {
|
||||
i: i % self.group_size for i in range(self.num_stages)
|
||||
}
|
||||
|
||||
|
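For a concrete feel of the default mapping above: with num_stages=4 and group_size=2 (illustrative numbers), stages are assigned round-robin to ranks:

num_stages, group_size = 4, 2
assert {i: i % group_size for i in range(num_stages)} == {0: 0, 1: 1, 2: 0, 3: 1}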
|
@ -258,12 +258,12 @@ class _PipelineStageBase(ABC):
|
|||
|
||||
def _create_grad_send_info(
|
||||
self,
|
||||
args_recv_info: Tuple,
|
||||
) -> List[Optional[int]]:
|
||||
args_recv_info: tuple,
|
||||
) -> list[Optional[int]]:
|
||||
"""
|
||||
Create a list of stage indices to send gradients to.
|
||||
"""
|
||||
grad_send_info: List[Optional[int]] = []
|
||||
grad_send_info: list[Optional[int]] = []
|
||||
|
||||
def map_recv_to_send(a):
|
||||
# Note: we send gradients back to previous stage as long as in
|
||||
|
|
@ -286,7 +286,7 @@ class _PipelineStageBase(ABC):
|
|||
self,
|
||||
num_microbatches: int,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[Any, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -303,19 +303,19 @@ class _PipelineStageBase(ABC):
|
|||
@abstractmethod
|
||||
def _create_grad_recv_info(
|
||||
self,
|
||||
act_send_info: Dict,
|
||||
act_send_info: dict,
|
||||
) -> tuple[_RecvInfo, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_recv_ops(
|
||||
self,
|
||||
recv_infos: tuple[InputInfo, ...],
|
||||
) -> List[dist.P2POp]:
|
||||
) -> list[dist.P2POp]:
|
||||
"""
|
||||
Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`.
|
||||
Returns a list of ops that correspond to the recv infos.
|
||||
"""
|
||||
ops: List[dist.P2POp] = []
|
||||
ops: list[dist.P2POp] = []
|
||||
for info in recv_infos:
|
||||
if not isinstance(info, _RecvInfo):
|
||||
continue
|
||||
|
|
@ -410,7 +410,7 @@ class _PipelineStageBase(ABC):
|
|||
), f"Expected a recv info, got {type(info)}"
|
||||
info.buffer = tensor
|
||||
|
||||
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
|
||||
def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
|
||||
"""
|
||||
Returns a list of ops that are needed to receive the input arguments
|
||||
for this stage.
|
||||
|
|
@ -419,7 +419,7 @@ class _PipelineStageBase(ABC):
|
|||
|
||||
return self._get_recv_ops(recv_infos)
|
||||
|
||||
def get_bwd_recv_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
|
||||
def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]:
|
||||
"""
|
||||
Returns a list of ops that are needed to receive the gradients
|
||||
for this stage.
|
||||
|
|
@ -430,7 +430,7 @@ class _PipelineStageBase(ABC):
|
|||
recv_infos = self.grad_recv_info[bwd_chunk_id]
|
||||
return self._get_recv_ops(recv_infos)
|
||||
|
||||
def get_fwd_send_ops(self, fwd_chunk_id: int) -> List[dist.P2POp]:
|
||||
def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
|
||||
"""
|
||||
Get the activation send ops for current stage's forward.
|
||||
"""
|
||||
|
|
@ -439,7 +439,7 @@ class _PipelineStageBase(ABC):
|
|||
# `act_send_info`
|
||||
output_tuple = output if type(output) is tuple else (output,)
|
||||
|
||||
ops: List[dist.P2POp] = []
|
||||
ops: list[dist.P2POp] = []
|
||||
|
||||
for idx, out in enumerate(output_tuple):
|
||||
dst_stages = self.act_send_info[idx]
|
||||
|
|
@ -462,7 +462,7 @@ class _PipelineStageBase(ABC):
|
|||
|
||||
return ops
|
||||
|
||||
def get_bwd_send_ops(self, bwd_chunk_id: int) -> List[dist.P2POp]:
|
||||
def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]:
|
||||
"""
|
||||
Get the gradient send ops for current stage's backward.
|
||||
"""
|
||||
|
|
@ -479,7 +479,7 @@ class _PipelineStageBase(ABC):
|
|||
# `grad_send_info` is a mirror of `args_recv_info`
|
||||
self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0])
|
||||
|
||||
ops: List[dist.P2POp] = []
|
||||
ops: list[dist.P2POp] = []
|
||||
grads_input = self.bwd_cache.pop(bwd_chunk_id)
|
||||
for grad, grad_recv_stage in zip(grads_input, self.grad_send_info):
|
||||
if isinstance(grad, torch.Tensor) and grad_recv_stage is not None:
|
||||
|
|
@ -593,9 +593,9 @@ class _PipelineStageBase(ABC):
|
|||
def backward_maybe_with_nosync(
|
||||
self,
|
||||
backward_type,
|
||||
bwd_kwargs: Dict,
|
||||
bwd_kwargs: dict,
|
||||
last_backward: bool = False,
|
||||
) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[List[Dict[str, Any]]]]:
|
||||
) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]:
|
||||
"""
|
||||
Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the
|
||||
other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but
|
||||
|
|
@ -607,7 +607,7 @@ class _PipelineStageBase(ABC):
|
|||
backward_type,
|
||||
) -> Callable[
|
||||
[],
|
||||
tuple[tuple[Optional[torch.Tensor], ...], Optional[List[Dict[str, Any]]]],
|
||||
tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]],
|
||||
]:
|
||||
if backward_type == "full":
|
||||
return lambda: (
|
||||
|
|
@ -686,7 +686,7 @@ class _PipelineStageBase(ABC):
|
|||
self,
|
||||
fwd_chunk_id: int,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Perform forward pass on the stage with one microbatch.
|
||||
|
|
@ -817,7 +817,7 @@ class _PipelineStageBase(ABC):
|
|||
"full", bwd_kwargs, last_backward=last_backward
|
||||
)
|
||||
else:
|
||||
param_groups: List[Dict[str, Any]] | None = None
|
||||
param_groups: list[dict[str, Any]] | None = None
|
||||
# Skip the backward for the first stage since we will perform the weight update with
|
||||
# autograd.backward in backward_weight_one_chunk
|
||||
if not self.is_first:
|
||||
|
|
@ -986,7 +986,7 @@ class _PipelineStage(_PipelineStageBase):
|
|||
)
|
||||
|
||||
# Create mapping from stage name to stage index
|
||||
self.submod_to_stage_index: Dict[str, int] = {}
|
||||
self.submod_to_stage_index: dict[str, int] = {}
|
||||
for i, node in enumerate(submod_nodes):
|
||||
self.submod_to_stage_index.setdefault(node.name, i)
|
||||
|
||||
|
|
@ -1010,7 +1010,7 @@ class _PipelineStage(_PipelineStageBase):
|
|||
self,
|
||||
num_microbatches: int,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
Create send/recv infrastructures for activations (during forward)
|
||||
|
|
@ -1084,7 +1084,7 @@ class _PipelineStage(_PipelineStageBase):
|
|||
buffer,
|
||||
)
|
||||
|
||||
args_recv_info: List[InputInfo] = []
|
||||
args_recv_info: list[InputInfo] = []
|
||||
# Filter out placeholder nodes from `self.submod` (a GraphModule)
|
||||
placeholders = filter( # type: ignore[var-annotated]
|
||||
lambda node: node.op == "placeholder", self.submod.graph.nodes # type: ignore[arg-type, union-attr]
|
||||
|
|
@ -1134,7 +1134,7 @@ class _PipelineStage(_PipelineStageBase):
|
|||
be consumed by multiple stages.
|
||||
"""
|
||||
# Output index: List of receiver ranks
|
||||
act_send_info: Dict[int, List] = {}
|
||||
act_send_info: dict[int, list] = {}
|
||||
out_idx = 0
|
||||
|
||||
for user in self.node.users:
|
||||
|
|
@ -1171,13 +1171,13 @@ class _PipelineStage(_PipelineStageBase):
|
|||
|
||||
def _create_grad_recv_info(
|
||||
self,
|
||||
act_send_info: Dict,
|
||||
act_send_info: dict,
|
||||
) -> tuple[_RecvInfo, ...]:
|
||||
"""
|
||||
Create a tuple of `_RecvInfo` for gradients.
|
||||
"""
|
||||
# Dict[output_index, _RecvInfo]
|
||||
grad_recv_info: Dict[int, _RecvInfo] = {}
|
||||
grad_recv_info: dict[int, _RecvInfo] = {}
|
||||
output_node = self._get_output_node()
|
||||
|
||||
# The output node may take multiple args, meaning the submod having multiple output values.
|
||||
|
|
@ -1275,7 +1275,7 @@ class PipelineStage(_PipelineStageBase):
|
|||
dw_builder: Optional[Callable[[], Callable[..., None]]] = None,
|
||||
):
|
||||
super().__init__(submodule, stage_index, num_stages, device, group, dw_builder)
|
||||
self.inputs: Optional[List[torch.Tensor]] = None
|
||||
self.inputs: Optional[list[torch.Tensor]] = None
|
||||
self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None
|
||||
# Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it
|
||||
# might be breaking for existing users.
|
||||
|
|
@ -1313,7 +1313,7 @@ class PipelineStage(_PipelineStageBase):
|
|||
)
|
||||
|
||||
# these are the buffers used in backwards send/recv, they are allocated later
|
||||
self.outputs_grad: List[torch.Tensor] = []
|
||||
self.outputs_grad: list[torch.Tensor] = []
|
||||
|
||||
def stage_global_rank(peer_rank):
|
||||
return (
|
||||
|
|
@ -1342,7 +1342,7 @@ class PipelineStage(_PipelineStageBase):
|
|||
def _shape_inference(
|
||||
self,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
|
@ -1443,7 +1443,7 @@ class PipelineStage(_PipelineStageBase):
|
|||
self,
|
||||
num_microbatches: int,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[Any, ...]:
|
||||
# TODO move self.device to an argument from step API (from its input tensors)?
|
||||
assert num_microbatches is not None, "TODO fix num_microbatches"
|
||||
|
|
@ -1481,7 +1481,7 @@ class PipelineStage(_PipelineStageBase):
|
|||
|
||||
# Send info during forward for each activation
|
||||
# only need the rank that is being sent to
|
||||
self.act_send_info: Dict[int, List] = {}
|
||||
self.act_send_info: dict[int, list] = {}
|
||||
|
||||
for idx in range(len(self.get_outputs_meta())):
|
||||
# We assume we always send to stage + 1
|
||||
|
|
@ -1494,7 +1494,7 @@ class PipelineStage(_PipelineStageBase):
|
|||
|
||||
def _create_grad_recv_info(
|
||||
self,
|
||||
act_send_info: Dict,
|
||||
act_send_info: dict,
|
||||
) -> tuple[_RecvInfo, ...]:
|
||||
grad_recv_info: tuple[_RecvInfo, ...] = ()
|
||||
if not self.is_last:
|
||||
|
|
|
|||
|
|
@ -9,15 +9,16 @@ except ImportError as e:
|
|||
import numbers
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from datetime import timedelta
|
||||
from typing import Callable, Dict, Iterator, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch.distributed import FileStore, Store, TCPStore
|
||||
|
||||
from .constants import default_pg_timeout
|
||||
|
||||
|
||||
_rendezvous_handlers: Dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {}
|
||||
_rendezvous_handlers: dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {}
|
||||
|
||||
__all__ = ["register_rendezvous_handler", "rendezvous"]
|
||||
|
||||
|
|
@ -54,14 +55,14 @@ def register_rendezvous_handler(scheme, handler):
|
|||
|
||||
# Query will have format "rank=0&world_size=1" and is
|
||||
# converted into {"rank": 0, "world_size": 1}
|
||||
def _query_to_dict(query: str) -> Dict[str, str]:
|
||||
def _query_to_dict(query: str) -> dict[str, str]:
|
||||
return {
|
||||
pair[0]: pair[1]
|
||||
for pair in (pair.split("=") for pair in filter(None, query.split("&")))
|
||||
}
|
||||
|
||||
|
||||
def _get_use_libuv_from_query_dict(query_dict: Dict[str, str]) -> bool:
|
||||
def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool:
|
||||
# libuv is the default backend for TCPStore. To enable the non-libuv backend,
|
||||
# user can explicitly specify ``use_libuv=0`` in the URL parameter.
|
||||
return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1"
|
||||
|
|
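A quick usage sketch of the two helpers in this hunk (the URL query is illustrative; note that values stay strings):

query_dict = _query_to_dict("rank=0&world_size=2&use_libuv=0")
# {'rank': '0', 'world_size': '2', 'use_libuv': '0'}
assert _get_use_libuv_from_query_dict(query_dict) is False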
|
|||
|
|
@ -3,8 +3,9 @@ import logging
|
|||
import os
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from typing import Generator, Tuple
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import functools
|
|||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, Generic, TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
from torch._C._distributed_rpc import (
|
||||
|
|
@ -115,9 +115,9 @@ class AllGatherStates:
|
|||
|
||||
# States used by `def _all_gather()`.
|
||||
# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer.
|
||||
_ALL_WORKER_NAMES: Set[Any] = set()
|
||||
_ALL_WORKER_NAMES: set[Any] = set()
|
||||
_all_gather_dict_lock = threading.RLock()
|
||||
_all_gather_sequence_id: Dict[str, int] = {}
|
||||
_all_gather_sequence_id: dict[str, int] = {}
|
||||
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
|
||||
AllGatherStates
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import collections
|
||||
import enum
|
||||
from typing import cast, Dict, List, Set
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -163,8 +163,8 @@ def _tensorpipe_validate_devices(devices, device_count):
|
|||
def _tensorpipe_exchange_and_check_all_device_maps(
|
||||
my_name, my_device_count, my_device_maps, my_devices, group
|
||||
):
|
||||
gathered: List[
|
||||
tuple[str, int, Dict[str, Dict[torch.device, torch.device]], List[torch.device]]
|
||||
gathered: list[
|
||||
tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]]
|
||||
] = [("", 0, {}, []) for _ in range(group.size())]
|
||||
dist.all_gather_object(
|
||||
gathered, (my_name, my_device_count, my_device_maps, my_devices), group
|
||||
|
|
@ -253,7 +253,7 @@ def _validate_device_maps(
|
|||
|
||||
def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
|
||||
if not my_devices:
|
||||
devices_set: Set[torch.device] = set()
|
||||
devices_set: set[torch.device] = set()
|
||||
for map_ in my_device_maps.values():
|
||||
devices_set.update(map_.keys())
|
||||
for map_ in reverse_device_maps.values():
|
||||
|
|
@ -265,7 +265,7 @@ def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
|
|||
|
||||
|
||||
def _create_reverse_mapping(my_name, all_names, all_device_maps):
|
||||
reverse_device_maps: Dict[str, Dict[torch.device, torch.device]] = {}
|
||||
reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {}
|
||||
for node in all_names:
|
||||
if my_name in all_device_maps[node]:
|
||||
reverse_device_maps[node] = {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from datetime import timedelta
|
||||
from typing import List
|
||||
|
||||
from torch._C._distributed_rpc import (
|
||||
_DEFAULT_INIT_METHOD,
|
||||
|
|
@ -22,4 +21,4 @@ DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1)
|
|||
# Value indicating that timeout is not set for RPC call, and the default should be used.
|
||||
UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
|
||||
|
|
@ -23,10 +23,10 @@ def _to_device(device: DeviceType) -> torch.device:
|
|||
|
||||
|
||||
def _to_device_map(
|
||||
device_map: Dict[DeviceType, DeviceType]
|
||||
) -> Dict[torch.device, torch.device]:
|
||||
full_device_map: Dict[torch.device, torch.device] = {}
|
||||
reverse_map: Dict[torch.device, torch.device] = {}
|
||||
device_map: dict[DeviceType, DeviceType]
|
||||
) -> dict[torch.device, torch.device]:
|
||||
full_device_map: dict[torch.device, torch.device] = {}
|
||||
reverse_map: dict[torch.device, torch.device] = {}
|
||||
for k, v in device_map.items():
|
||||
k, v = torch.device(k), torch.device(v)
|
||||
if v in reverse_map:
|
||||
|
|
@ -39,7 +39,7 @@ def _to_device_map(
|
|||
return full_device_map
|
||||
|
||||
|
||||
def _to_device_list(devices: List[DeviceType]) -> List[torch.device]:
|
||||
def _to_device_list(devices: list[DeviceType]) -> list[torch.device]:
|
||||
return list(map(_to_device, devices))
|
||||
|
||||
|
||||
|
|
@ -83,10 +83,10 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
|||
num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
|
||||
rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
|
||||
init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
|
||||
device_maps: Optional[Dict[str, Dict[DeviceType, DeviceType]]] = None,
|
||||
devices: Optional[List[DeviceType]] = None,
|
||||
_transports: Optional[List] = None,
|
||||
_channels: Optional[List] = None,
|
||||
device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None,
|
||||
devices: Optional[list[DeviceType]] = None,
|
||||
_transports: Optional[list] = None,
|
||||
_channels: Optional[list] = None,
|
||||
):
|
||||
full_device_maps = (
|
||||
{}
|
||||
|
|
@ -104,7 +104,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
|||
full_device_list,
|
||||
)
|
||||
|
||||
def set_device_map(self, to: str, device_map: Dict[DeviceType, DeviceType]):
|
||||
def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]):
|
||||
r"""
|
||||
Set device mapping between each RPC caller and callee pair. This
|
||||
function can be called multiple times to incrementally add
|
||||
|
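A short usage sketch of set_device_map as described (the worker name and device indices are placeholders):

from torch.distributed.rpc import TensorPipeRpcBackendOptions

options = TensorPipeRpcBackendOptions(num_worker_threads=8)
# Tensors sent from this worker's cuda:0 land on "worker1"'s cuda:1.
options.set_device_map("worker1", {0: 1})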
|
@ -162,7 +162,7 @@ class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
|
|||
|
||||
super()._set_device_map(to, full_device_map)
|
||||
|
||||
def set_devices(self, devices: List[DeviceType]):
|
||||
def set_devices(self, devices: list[DeviceType]):
|
||||
r"""
|
||||
Set local devices used by the TensorPipe RPC agent. When processing
|
||||
CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
import itertools
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler_legacy import profile
|
||||
|
|
@ -13,7 +12,7 @@ from . import (
|
|||
)
|
||||
|
||||
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
|
||||
class _server_process_global_profile(profile):
|
||||
|
|
|
|||
|
|
@ -398,7 +398,7 @@ import sys
|
|||
import uuid
|
||||
from argparse import ArgumentParser, REMAINDER
|
||||
from importlib import metadata
|
||||
from typing import Callable, List, Optional, Set, Type, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.argparse_util import check_env, env
|
||||
|
|
@ -736,7 +736,7 @@ def get_use_env(args) -> bool:
|
|||
return args.use_env
|
||||
|
||||
|
||||
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
|
||||
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
|
||||
"""
|
||||
Attempts to load the `torchrun.logs_spec` entrypoint keyed by the `logs_specs_name` param.
|
||||
Provides plugin mechanism to provide custom implementation of LogsSpecs.
|
||||
|
|
@ -770,7 +770,7 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
|
|||
return logs_specs_cls
|
||||
|
||||
|
||||
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str]]:
|
||||
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
|
||||
# If ``args`` not passed, defaults to ``sys.argv[:1]``
|
||||
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
|
||||
assert 0 < min_nodes <= max_nodes
|
||||
|
|
@ -810,7 +810,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
|
|||
|
||||
rdzv_endpoint = get_rdzv_endpoint(args)
|
||||
|
||||
ranks: Optional[Set[int]] = None
|
||||
ranks: Optional[set[int]] = None
|
||||
if args.local_ranks_filter:
|
||||
try:
|
||||
ranks = set(map(int, args.local_ranks_filter.split(",")))
|
||||
|
|
@ -820,7 +820,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], List[str
|
|||
"--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
|
||||
) from e
|
||||
|
||||
logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
|
||||
logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
|
||||
logs_specs = logs_specs_cls(
|
||||
log_dir=args.log_dir,
|
||||
redirects=Std.from_str(args.redirects),
|
||||
|
|
|
|||
|
|
@ -1,18 +1,9 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import dataclasses
|
||||
import traceback
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
OrderedDict,
|
||||
overload,
|
||||
Set,
|
||||
TypeVar,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Container
|
||||
from typing import Any, Callable, Optional, overload, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -43,8 +34,8 @@ def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str,
|
|||
kwarg keys. The second tuple element gives the kwarg keys.
|
||||
The second tuple element's length is at most the first tuple element's length.
|
||||
"""
|
||||
kwarg_keys: List[str] = []
|
||||
flat_args: List[Any] = list(args)
|
||||
kwarg_keys: list[str] = []
|
||||
flat_args: list[Any] = list(args)
|
||||
for k, v in kwargs.items():
|
||||
kwarg_keys.append(k)
|
||||
flat_args.append(v)
|
||||
|
|
@ -75,7 +66,7 @@ def _cast_forward_inputs(
|
|||
|
||||
def _unpack_kwargs(
|
||||
flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...]
|
||||
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
"""See _pack_kwargs."""
|
||||
assert len(kwarg_keys) <= len(
|
||||
flat_args
|
||||
|
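A small round-trip sketch of the two helpers above (values are illustrative):

flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
# flat_args == (1, 2, 3, 4), kwarg_keys == ("a", "b")
args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
# args == (1, 2), kwargs == {"a": 3, "b": 4}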
|
@ -94,7 +85,7 @@ T = TypeVar("T", torch.Tensor, PackedSequence)
|
|||
@overload
|
||||
def _recursive_to(
|
||||
inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
|
||||
) -> List[S]:
|
||||
) -> list[S]:
|
||||
...
|
||||
|
||||
|
||||
|
|
@ -264,10 +255,10 @@ def _apply_to_tensors(fn, container):
|
|||
|
||||
def _to_kwargs(
|
||||
inputs: tuple[Any, ...],
|
||||
kwargs: Optional[Dict[str, Any]],
|
||||
kwargs: Optional[dict[str, Any]],
|
||||
target_device: torch.device,
|
||||
use_side_stream_for_tensor_copies: bool,
|
||||
) -> tuple[tuple[Any, ...], tuple[Dict[str, Any], ...]]:
|
||||
) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]:
|
||||
moved_inputs = (
|
||||
_recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
|
||||
if inputs
|
||||
|
|
@ -287,7 +278,7 @@ def _to_kwargs(
|
|||
|
||||
def _verify_param_shape_across_processes(
|
||||
process_group: dist.ProcessGroup,
|
||||
tensors: List[torch.Tensor],
|
||||
tensors: list[torch.Tensor],
|
||||
logger: Optional["dist.Logger"] = None,
|
||||
):
|
||||
return dist._verify_params_across_processes(process_group, tensors, logger)
|
||||
|
|
@ -309,7 +300,7 @@ def _sync_module_states(
|
|||
parameter shapes are consistent before running the synchronization. This can
|
||||
be checked with ``_verify_param_shape_across_processes``.
|
||||
"""
|
||||
module_states: List[torch.Tensor] = []
|
||||
module_states: list[torch.Tensor] = []
|
||||
for name, param in module.named_parameters():
|
||||
if name not in params_and_buffers_to_ignore:
|
||||
module_states.append(param.detach())
|
||||
|
|
@ -324,7 +315,7 @@ def _sync_module_states(
|
|||
|
||||
def _sync_params_and_buffers(
|
||||
process_group: dist.ProcessGroup,
|
||||
module_states: List[torch.Tensor],
|
||||
module_states: list[torch.Tensor],
|
||||
broadcast_bucket_size: int,
|
||||
src: int,
|
||||
) -> None:
|
||||
|
|
@ -336,7 +327,7 @@ def _sync_params_and_buffers(
|
|||
|
||||
|
||||
def _replace_by_prefix(
|
||||
state_dict: Dict[str, Any],
|
||||
state_dict: dict[str, Any],
|
||||
old_prefix: str,
|
||||
new_prefix: str,
|
||||
) -> None:
|
||||
|
|
@ -363,15 +354,15 @@ def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
|
|||
return tensor.untyped_storage().data_ptr() > 0
|
||||
|
||||
|
||||
def _get_root_modules(modules: List[nn.Module]) -> List[nn.Module]:
|
||||
def _get_root_modules(modules: list[nn.Module]) -> list[nn.Module]:
|
||||
"""
|
||||
Returns the modules in ``modules`` that are root modules (i.e.
|
||||
parent-less) with respect to the set ``modules``. In other words, these
|
||||
are the modules in ``modules`` that are the not child of any other
|
||||
module in ``modules``.
|
||||
"""
|
||||
root_modules: List[nn.Module] = []
|
||||
module_to_modules: Dict[nn.Module, Set[nn.Module]] = {
|
||||
root_modules: list[nn.Module] = []
|
||||
module_to_modules: dict[nn.Module, set[nn.Module]] = {
|
||||
module: set(module.modules()) for module in modules
|
||||
}
|
||||
for candidate_module in modules:
|
||||
|
|
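A usage sketch of the root-module filter above (hypothetical modules, just to show the expected result):

import torch.nn as nn

inner = nn.Linear(4, 4)
outer = nn.Sequential(inner, nn.ReLU())
# `inner` is contained in `outer`, so only `outer` is a root within this set.
assert _get_root_modules([outer, inner]) == [outer]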
|
|||