PEP585 update - torch/distributed/tensor (#145141)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145141
Approved by: https://github.com/bobrenjc93
Aaron Orenstein 2025-01-17 22:11:10 -08:00 committed by PyTorch MergeBot
parent 4f8237dbad
commit c95efc37ba
36 changed files with 269 additions and 260 deletions
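
For context, the diff below applies the PEP 585 style mechanically across the package: container annotations imported from typing (e.g. List, Dict, Set) become the built-in generics (list, dict, set), abstract types such as Sequence and Iterable are now imported from collections.abc, runtime isinstance checks use the built-in types, and functools.lru_cache(maxsize=None) is written as functools.cache. A minimal before/after sketch of the pattern, using nothing beyond the standard library; the function and parameter names are made up for illustration and do not appear in this diff:

# --- before: typing-module generics and lru_cache(maxsize=None) ---
from functools import lru_cache
from typing import Dict, List, Optional, Sequence, Tuple


def pad_sizes_old(shape: Sequence[int], pad: Optional[int]) -> List[int]:
    # illustrative helper: pad every dim size by `pad` (or 0 if None)
    sizes: Dict[int, int] = {i: s + (pad or 0) for i, s in enumerate(shape)}
    return list(sizes.values())


@lru_cache(maxsize=None)  # unbounded cache, pre-3.9 spelling
def mesh_volume_old(mesh_shape: Tuple[int, ...]) -> int:
    out = 1
    for s in mesh_shape:
        out *= s
    return out


# --- after: built-in generics (PEP 585), collections.abc, functools.cache ---
from collections.abc import Sequence
from functools import cache
from typing import Optional  # Optional/Union spellings are unchanged by this rewrite


def pad_sizes_new(shape: Sequence[int], pad: Optional[int]) -> list[int]:
    sizes: dict[int, int] = {i: s + (pad or 0) for i, s in enumerate(shape)}
    return list(sizes.values())


@cache  # equivalent to lru_cache(maxsize=None)
def mesh_volume_new(mesh_shape: tuple[int, ...]) -> int:
    out = 1
    for s in mesh_shape:
        out *= s
    return out


# Runtime checks follow the same rule: isinstance(x, list) rather than isinstance(x, List).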

View File

@ -3,7 +3,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import inspect
import warnings
from typing import Any, Callable, cast, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Callable, cast, Optional
from typing_extensions import deprecated
import torch

View File

@ -3,7 +3,7 @@ import logging
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional
from typing import Optional
import torch
import torch.distributed._functional_collectives as funcol
@ -73,7 +73,7 @@ def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
def mesh_scatter(
output: torch.Tensor,
scatter_list: List[torch.Tensor],
scatter_list: list[torch.Tensor],
mesh: DeviceMesh,
mesh_dim: int = 0,
async_op: bool = False,
@ -195,8 +195,8 @@ def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Ten
def fill_empty_tensor_to_shards(
shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> List[torch.Tensor]:
shards: list[torch.Tensor], shard_dim: int, num_empty_tensors: int
) -> list[torch.Tensor]:
if num_empty_tensors == 0:
return shards
tensor_size = list(shards[0].size())
@ -244,9 +244,9 @@ class MeshTopoInfo:
"""
mesh: DeviceMesh
mesh_dim_devices: List[int]
mesh_dim_bandwidth: List[float]
mesh_dim_latency: List[float]
mesh_dim_devices: list[int]
mesh_dim_bandwidth: list[float]
mesh_dim_latency: list[float]
@staticmethod
@lru_cache(None)

View File

@ -4,7 +4,8 @@ import functools
import logging
import operator
import warnings
from typing import cast, Dict, List, Optional, Sequence, TYPE_CHECKING
from collections.abc import Sequence
from typing import cast, Optional, TYPE_CHECKING
import torch
import torch.distributed as dist
@ -44,7 +45,7 @@ logger = logging.getLogger(__name__)
def decompose_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
"""
Decomposes an op to core ATen ops; this handler is mostly here
@ -60,7 +61,7 @@ def decompose_handler(
def is_same_size_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> bool:
lhs = cast(torch.Tensor, args[0])
rhs = cast(torch.Tensor, args[1])
@ -70,11 +71,11 @@ def is_same_size_handler(
def found_inf_reduce_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> None:
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
local_tensor_args = pytree.tree_unflatten(
cast(List[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type]
cast(list[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type]
)
local_tensor_args = cast(tuple[object, ...], local_tensor_args)
op_call(*local_tensor_args, **op_info.local_kwargs)
@ -154,7 +155,7 @@ class OpDispatcher:
self,
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
"""
Main dispatching logic
@ -185,7 +186,7 @@ class OpDispatcher:
local_tensor_args = (
pytree.tree_unflatten(
cast(List[object], op_info.local_args), op_info.args_tree_spec
cast(list[object], op_info.local_args), op_info.args_tree_spec
)
if op_info.args_tree_spec
else op_info.local_args
@ -251,7 +252,7 @@ class OpDispatcher:
local_results = [
default_tensor(s) if s is not None else None for s in spec
]
assert isinstance(local_results, List)
assert isinstance(local_results, list)
if None in local_results:
ret_type = str(ret_list[0].type)
raise NotImplementedError(
@ -309,7 +310,7 @@ class OpDispatcher:
else:
flatten_args_schema_to_reshard = suggested_input_schema.args_schema
new_local_args: List[object] = []
new_local_args: list[object] = []
for i, arg_spec in enumerate(op_info.flat_args_schema):
reshard_arg_spec = flatten_args_schema_to_reshard[i]
if isinstance(arg_spec, DTensorSpec):
@ -330,7 +331,7 @@ class OpDispatcher:
self,
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> OpInfo:
# get runtime schema info to determine whether to use pytree to flatten inputs
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
@ -344,10 +345,10 @@ class OpDispatcher:
else:
args_list, args_spec = args, None
args_schema: List[object] = []
kwargs_schema: Dict[str, object] = {}
local_args: List[object] = []
local_kwargs: Dict[str, object] = {}
args_schema: list[object] = []
kwargs_schema: dict[str, object] = {}
local_args: list[object] = []
local_kwargs: dict[str, object] = {}
mesh: Optional[DeviceMesh] = None
for arg in args_list:

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, cast, List, NamedTuple, Optional
from typing import Any, cast, NamedTuple, Optional
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -133,7 +133,7 @@ class DTensorSpec:
return self.mesh
@property
def dim_map(self) -> List[int]:
def dim_map(self) -> list[int]:
"""
dim_map is a property we derive from `placements` of
the distributed tensor. It simply returns a list of ints
@ -170,7 +170,7 @@ class DTensorSpec:
return r
@property
def num_shards_map(self) -> List[int]:
def num_shards_map(self) -> list[int]:
"""
num_shards_map is a property we derive from `placements` of
the distributed tensor. Unlike `dim_map`, `num_shards_map`
@ -193,7 +193,7 @@ class DTensorSpec:
return r
@property
def sums(self) -> List[int]:
def sums(self) -> list[int]:
"""
sums is a property we derive from `placements` of the
distributed tensor. It simply returns a list of ints where
@ -209,8 +209,8 @@ class DTensorSpec:
def from_dim_map(
cls,
mesh: DeviceMesh,
dim_map: List[int],
sums: List[int],
dim_map: list[int],
sums: list[int],
tensor_meta: Optional[TensorMeta] = None,
) -> "DTensorSpec":
"""
@ -228,7 +228,7 @@ class DTensorSpec:
a class:`DTensorSpec` object
"""
# by default replicate on device mesh dims
placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
placements: list[Placement] = [Replicate() for _ in range(mesh.ndim)]
# find all mesh dims that need pending reductions
for s in sums:

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Optional, Union
import torch
from torch._ops import OpOverload
@ -22,9 +23,9 @@ except ImportError:
# Common type aliases
ArgsType = tuple[object, ...]
KwargsType = Dict[str, object]
KwargsType = dict[str, object]
PlacementList = List[Optional[Placement]]
PlacementList = list[Optional[Placement]]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type should
# be the same set of possibilities.
@ -86,7 +87,7 @@ class PlacementStrategy:
# we need a nested list to record the cost for each
# operand of this operator, and for each operand of
# this operator it might have multiple placement strategies
redistribute_cost: Optional[List[List[float]]] = None
redistribute_cost: Optional[list[list[float]]] = None
@cached_property
def output_spec(self) -> DTensorSpec:
@ -130,9 +131,9 @@ class OpStrategy(StrategyType):
OpStrategy that consists of a list of placement strategies associated with the op
"""
def __init__(self, strategies: List[PlacementStrategy]) -> None:
def __init__(self, strategies: list[PlacementStrategy]) -> None:
super().__init__()
self.strategies: List[PlacementStrategy] = strategies
self.strategies: list[PlacementStrategy] = strategies
def __str__(self) -> str:
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
@ -203,7 +204,7 @@ class RuntimeSchemaInfo:
# Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
static_argnum: int = 100
# This static_kwargkey records static kwarg names which would affect sharding prop
static_kwargkey: Optional[List[str]] = None
static_kwargkey: Optional[list[str]] = None
# each op can decide if it wants to use pytree flatten/unflatten during operator
# eager execution; by default we don't need to do flatten/unflatten, only if the
# op indicates it needs to. This is to accelerate eager performance.
@ -270,7 +271,7 @@ class OpSchema:
)
def __str__(self) -> str:
args_schema: List[str] = []
args_schema: list[str] = []
mesh_shape = None
for arg in self.args_schema:
if isinstance(arg, DTensorSpec):
@ -404,7 +405,7 @@ class OpSchema:
def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
suggestion_args_spec = self.args_spec
new_arg_schema: List[object] = []
new_arg_schema: list[object] = []
idx_of_args_spec = 0
if (
origin_schema.schema_info is not None
@ -448,9 +449,9 @@ class OpInfo:
mesh: DeviceMesh
schema: OpSchema
flat_args_schema: List[object]
flat_args_schema: list[object]
local_args: Sequence[object]
local_kwargs: Dict[str, object]
local_kwargs: dict[str, object]
args_tree_spec: Optional[TreeSpec] = None
# the output sharding info

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import string
from typing import cast, Dict, List, Optional
from typing import cast, Optional
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
@ -20,12 +20,12 @@ def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
def _gen_reshard_suggestions(
op_schema: OpSchema,
input_dims: List[str],
input_dims: list[str],
input_specs: tuple[DTensorSpec, ...],
dim_to_sharding: Dict[str, int],
pending_sum: List[int],
dim_to_sharding: dict[str, int],
pending_sum: list[int],
) -> OutputSharding:
suggested_arg_specs: List[DTensorSpec] = []
suggested_arg_specs: list[DTensorSpec] = []
for input_dim, input_spec in zip(input_dims, input_specs):
dim_map = [dim_to_sharding[dim] for dim in input_dim]
suggested_arg_specs.append(
@ -49,7 +49,7 @@ def einop_rule(
op_schema: OpSchema,
*,
linearity: bool = False,
enforce_sharding: Optional[Dict[str, int]] = None,
enforce_sharding: Optional[dict[str, int]] = None,
) -> OutputSharding:
"""
Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
@ -77,12 +77,12 @@ def einop_rule(
# NOTE: only support single output unless needed in future
output_dim = output_dims[0]
dim_to_sharding: Dict[str, int] = {}
dim_to_size: Dict[str, int] = {}
dim_to_sharding: dict[str, int] = {}
dim_to_size: dict[str, int] = {}
# record pending sum, key is mesh dimension, value is pending sum
# counter across input specs
pending_sums_counter: Dict[int, int] = {}
seen_shardings: Dict[int, str] = {}
pending_sums_counter: dict[int, int] = {}
seen_shardings: dict[int, str] = {}
needs_reshard = False
def merge_sharding(dim: str, a: int, b: int) -> int:
@ -240,7 +240,7 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
input_specs = op_schema.args_spec
max_dim = max(input.ndim for input in input_specs)
dimchars = []
singleton_counter: List[int] = [0] * max_dim
singleton_counter: list[int] = [0] * max_dim
for input in input_specs:
start_dim = max_dim - input.ndim
p = alphabet[start_dim:max_dim]
@ -271,7 +271,7 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}"
enforce_sharding: Dict[str, int] = {}
enforce_sharding: dict[str, int] = {}
if _is_inplace_op(op_schema.op):
# inplace op should keep the input sharding it writes to
for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map):

View File

@ -1,6 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import List
import torch
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
@ -32,9 +31,9 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
assert weight_spec.tensor_meta is not None
in_shape = input_spec.tensor_meta.shape
weight_shape = weight_spec.tensor_meta.shape
assert isinstance(stride, List)
assert isinstance(padding, List)
assert isinstance(dilation, List)
assert isinstance(stride, list)
assert isinstance(padding, list)
assert isinstance(dilation, list)
assert isinstance(weight_shape, torch.Size)
N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3]
C_out = weight_shape[0]
@ -84,7 +83,7 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
assert isinstance(grad_output_spec, DTensorSpec)
assert isinstance(input_spec, DTensorSpec)
assert isinstance(weight_spec, DTensorSpec)
assert isinstance(bias_shape_opt, List)
assert isinstance(bias_shape_opt, list)
assert input_spec.tensor_meta is not None
weight_tensor_meta = weight_spec.tensor_meta
bias_tensor_meta = TensorMeta(

View File

@ -1,6 +1,5 @@
import itertools
from dataclasses import dataclass
from typing import List, Set
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
@ -15,13 +14,13 @@ from torch.distributed.tensor.placement_types import (
@dataclass
class EinsumDims:
contracting_dims: List[str]
batch_dims: List[str]
lhs_out_only_dims: List[str]
rhs_out_only_dims: List[str]
contracting_dims: list[str]
batch_dims: list[str]
lhs_out_only_dims: list[str]
rhs_out_only_dims: list[str]
@classmethod
def parse_equation(cls, equation: str) -> tuple[List[str], str]:
def parse_equation(cls, equation: str) -> tuple[list[str], str]:
# parse einop equation and extract arg specs
"""
Parse the einsum equation str to input dim chars and output dim char
@ -37,12 +36,12 @@ class EinsumDims:
return input_dims, output_dim
@classmethod
def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
def parse_dims(cls, input_dims: list[str], output_dim: str) -> "EinsumDims":
"""
Parse the dims and extract the contracting, batch, and free dimensions
for the left and right hand sides.
"""
dim_char_set: Set[str] = set()
dim_char_set: set[str] = set()
for input_dim in input_dims:
dim_char_set.update(input_dim)
@ -104,7 +103,7 @@ def gen_einsum_strategies(
# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
mesh_dim_strategies.append(placement_list)
# split batch dim
@ -131,7 +130,7 @@ def gen_einsum_strategies(
lhs_free_dim = output_dim.index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: List[Placement] = [
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim),
Shard(lhs_free_dim),
Replicate(),
@ -141,7 +140,7 @@ def gen_einsum_strategies(
# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim = output_dim.index(rhs_dim)
rhs_placement_list: List[Placement] = [
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim),
Replicate(),
Shard(rhs_free_dim),
@ -150,7 +149,7 @@ def gen_einsum_strategies(
# linearity strategy
if linearity:
linearity_placement_list: List[Placement] = [Partial()]
linearity_placement_list: list[Placement] = [Partial()]
for input_dim in input_dims:
linearity_placement_list.append(Partial())
mesh_dim_strategies.append(linearity_placement_list)

View File

@ -1,9 +1,10 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import cast, List, Optional, Sequence, Union
from typing import cast, Optional, Union
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -148,11 +149,11 @@ class _NormPartial(Partial):
return 1 + hash(self.norm_type)
def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]:
def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]:
if dims_arg is None:
return None
dims = cast(List[int], as_list(dims_arg))
dims = cast(List[int], normalize_dims(dims, ndim))
dims = cast(list[int], as_list(dims_arg))
dims = cast(list[int], normalize_dims(dims, ndim))
empty_dims = [[0], [-1], []]
if ndim == 0 and dims_arg in empty_dims:
return None
@ -160,8 +161,8 @@ def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]:
def _infer_reduce_dims_map(
reduction_dims: List[int], input_ndim: int, keep_dim=False
) -> List[int]:
reduction_dims: list[int], input_ndim: int, keep_dim=False
) -> list[int]:
reduction_dims_map = []
new_dim_count = 0
for input_dim in range(input_ndim):
@ -179,7 +180,7 @@ def _infer_reduce_dims_map(
def _replicate_dims_start_at(
placements: Sequence[Placement], start_dim: int = 0
) -> tuple[Placement, ...]:
new_placements: List[Placement] = []
new_placements: list[Placement] = []
for p in placements:
if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
new_placements.append(Replicate()) # make it replicate
@ -192,7 +193,7 @@ def _replicate_dims_start_at(
def _skip_dim(
placements: tuple[Placement, ...], skipped_dim: int
) -> tuple[Placement, ...]:
new_placements: List[Placement] = []
new_placements: list[Placement] = []
for p in placements:
if isinstance(p, Shard) and p.dim >= skipped_dim:
new_placements.append(Shard(p.dim - 1))
@ -202,10 +203,10 @@ def _skip_dim(
def replicate_reduction_dims(
placements: tuple[Placement, ...], reduction_dims: List[int]
placements: tuple[Placement, ...], reduction_dims: list[int]
) -> tuple[Placement, ...]:
# replicate the reduction dims if not reduction_linear
new_placements: List[Placement] = []
new_placements: list[Placement] = []
for p in placements:
if p.is_partial():
@ -220,14 +221,14 @@ def replicate_reduction_dims(
def map_placements_after_reduction(
placements: tuple[Placement, ...],
reduction_dims: List[int],
reduction_dims_map: List[int],
reduction_dims: list[int],
reduction_dims_map: list[int],
reduction_op: ReductionOpType,
) -> tuple[Placement, ...]:
"""
Map each placement based on the output shape after reduction.
"""
new_placements: List[Placement] = []
new_placements: list[Placement] = []
for placement in placements:
if isinstance(placement, (Replicate, Partial)):
new_placements.append(placement)
@ -253,7 +254,7 @@ def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
def common_reduction_strategy(
mesh: DeviceMesh,
input_strategy: OpStrategy,
reduce_dims: List[int],
reduce_dims: list[int],
keep_dim: bool = False,
reduction_linear: bool = True,
reduction_op: ReductionOpType = "sum",
@ -406,7 +407,7 @@ def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrateg
assert isinstance(input_tuple_strategy, TupleStrategy)
norm_type = args_schema[1] if len(args_schema) > 1 else 2
assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
output_tuple_strategy_childs: List[OpStrategy] = []
output_tuple_strategy_childs: list[OpStrategy] = []
for op_strategy in input_tuple_strategy.childs:
assert isinstance(op_strategy, OpStrategy), f"{op_strategy}"
reduce_dims = list(range(op_strategy.ndim))
@ -451,7 +452,7 @@ def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate
args_schema = op_schema.args_schema
input_strategy = args_schema[0]
assert isinstance(input_strategy, OpStrategy), f"{input_strategy}"
output_strategies: List[PlacementStrategy] = []
output_strategies: list[PlacementStrategy] = []
for placement_strategy in input_strategy.strategies:
replicate_placements = tuple(Replicate() for _ in range(mesh.ndim))
replicate_spec = DTensorSpec(
@ -912,14 +913,14 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
axis = input_ndim - len(normalized_size)
outer_dims = list(range(axis))
assert isinstance(output_mask, List) and len(output_mask) == 3
assert isinstance(output_mask, list) and len(output_mask) == 3
# output triple: (d_input, d_weight, d_bias)
out_tuple_strategy = OpStrategy([])
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
# args for PlacementStrategy
output_specs_list: List[Optional[DTensorSpec]] = []
input_specs_list: List[DTensorSpec] = []
output_specs_list: list[Optional[DTensorSpec]] = []
input_specs_list: list[DTensorSpec] = []
redistribute_costs = []
input_src_spec = input_placement_strategy.output_spec

View File

@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import List
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -415,7 +414,7 @@ def scaled_dot_product_efficient_attention_strategy(
has_attn_bias = op_schema.args_schema[3] is not None
compute_log_sumexp = op_schema.args_schema[4]
single_mesh_dim_strategies: List[PlacementList] = []
single_mesh_dim_strategies: list[PlacementList] = []
# placement list stores placements of [outputs, inputs]
# in the sdpa case, we have 2 valid tensor outputs and 3 or 4 tensor inputs

View File

@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import List, Sequence
from collections.abc import Sequence
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -464,7 +464,7 @@ def common_pointwise_strategy(
for placement_strategy in followed_strategy.strategies:
spec_to_follow = placement_strategy.output_spec
out_placements: List[Placement] = []
out_placements: list[Placement] = []
for placement in spec_to_follow.placements:
if isinstance(placement, Shard):
shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
@ -479,8 +479,8 @@ def common_pointwise_strategy(
else:
out_placements.append(placement)
input_specs: List[DTensorSpec] = []
redistribute_costs: List[List[float]] = []
input_specs: list[DTensorSpec] = []
redistribute_costs: list[list[float]] = []
for input_arg in args_schema:
if isinstance(input_arg, OpStrategy):
# every arg follows the out_placements, but we need to handle broadcasting
@ -616,11 +616,11 @@ def list_pointwise_strategy(
OpStrategy: generated strategy
"""
def args_tuple_strategies(args_schema: tuple[object, ...]) -> List[TupleStrategy]:
def args_tuple_strategies(args_schema: tuple[object, ...]) -> list[TupleStrategy]:
first_arg = args_schema[0]
assert isinstance(first_arg, TupleStrategy)
strategy_len = len(first_arg.childs)
tuple_strategies: List[TupleStrategy] = []
tuple_strategies: list[TupleStrategy] = []
for arg_idx, arg in enumerate(args_schema):
if isinstance(arg, TupleStrategy):
# every tuple strategy should have the same length
@ -639,10 +639,10 @@ def list_pointwise_strategy(
args_strategies = args_tuple_strategies(op_schema.args_schema)
follow_strategy: TupleStrategy = args_strategies[0]
list_strategy: List[OpStrategy] = []
list_strategy: list[OpStrategy] = []
for child_idx, child_strtgy in enumerate(follow_strategy.childs):
assert isinstance(child_strtgy, OpStrategy)
args_schema: List[StrategyType] = [
args_schema: list[StrategyType] = [
arg_strategy.childs[child_idx] for arg_strategy in args_strategies
]
pointwise_strategy: OpStrategy = common_pointwise_strategy(

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, List, Optional, Sequence, Sized
from collections.abc import Sequence, Sized
from typing import cast, Optional
import torch
from torch.distributed.device_mesh import DeviceMesh
@ -465,7 +466,7 @@ def _derive_follow_placements_from_tuple_strategy(
# current replicate, just follow new placement
return new_placement
follow_placements: Optional[List[Placement]] = None
follow_placements: Optional[list[Placement]] = None
for arg_strategy in tuple_strategy.childs:
assert isinstance(arg_strategy, OpStrategy)
for placement_strategy in arg_strategy.strategies:
@ -489,7 +490,7 @@ def normalize_shard_for_stack(
) -> Sequence[Placement]:
# stack op would "insert" a new dim, so all sharded dims >= the inserted dim need to
# be normalized with the new Shard placement
normalized_placements: List[Placement] = []
normalized_placements: list[Placement] = []
for placement in placements:
if isinstance(placement, Shard) and placement.dim >= insert_dim:
normalized_placements.append(Shard(placement.dim + 1))
@ -575,7 +576,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding:
assert isinstance(dim, int)
assert isinstance(indices_spec, DTensorSpec)
all_indices_spec: List[Optional[DTensorSpec]] = [
all_indices_spec: list[Optional[DTensorSpec]] = [
indices_spec if dim == i else None for i in range(values_spec.ndim)
]
@ -620,8 +621,8 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
values_spec, multi_indices_spec = op_schema.args_schema
assert isinstance(values_spec, DTensorSpec)
assert isinstance(multi_indices_spec, list)
multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec)
valid_indices_spec: List[tuple[int, DTensorSpec]] = [
multi_indices_spec = cast(list[Optional[DTensorSpec]], multi_indices_spec)
valid_indices_spec: list[tuple[int, DTensorSpec]] = [
(i, a) for i, a in enumerate(multi_indices_spec) if a is not None
]
@ -731,7 +732,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
schema_info=RuntimeSchemaInfo(1),
)
def split_rule(op_schema: OpSchema) -> OutputSharding:
output_spec_list: List[DTensorSpec] = []
output_spec_list: list[DTensorSpec] = []
input_spec = cast(DTensorSpec, op_schema.args_schema[0])
ndim = input_spec.ndim
split_size_or_sections = op_schema.args_schema[1]
@ -769,7 +770,7 @@ def split_rule(op_schema: OpSchema) -> OutputSharding:
),
)
def size_split(N, i) -> List:
def size_split(N, i) -> list:
# Last chunk will be smaller if the tensor size N
# along the given dimension dim is not divisible by i.
assert i > 0

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Set, Union
from typing import Callable, cast, Optional, Union
import torch
from torch import Tensor
@ -216,7 +217,7 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
# other dims are passed through
if end_dim < 0:
end_dim += ndim
results: List[DimSpec] = [InputDim(i) for i in range(start_dim)]
results: list[DimSpec] = [InputDim(i) for i in range(start_dim)]
results.append(
Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1)))
)
@ -414,7 +415,7 @@ def dim_unsqueeze(ndim: int, dim: int) -> DimMap:
def dim_view_as_real(shape: Shape) -> DimMap:
ndim = len(shape)
results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
results: list[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
# each complex number is split into two real numbers,
# resulting in one more dimension of size 2
results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0))
@ -442,7 +443,7 @@ def dim_reduction(
)
dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
dim_maps: dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
@ -488,11 +489,11 @@ def propagate_shape_and_sharding(
assert len(input_src_placements) == len(mesh_sizes)
# for each input dim, for each mesh dim, provides a list of possible shardable dimensions
mesh_ndim = len(mesh_sizes)
shardable_dims: Dict[int, List[bool]] = {}
shardable_dims: dict[int, list[bool]] = {}
# in case an input dimension disappears (e.g. collapsing, reduction)
# we cannot shard in that dimension (we need a replication fall-back rule)
seen_input_dims: Set[int] = set()
seen_input_dims: set[int] = set()
def collect_used_inputs(cmd: DimSpec) -> None:
if isinstance(cmd, InputDim):

View File

@ -3,7 +3,8 @@
import functools
import itertools
import operator
from typing import Callable, cast, Iterable, List, Optional, Sequence, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import Callable, cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -35,7 +36,7 @@ _P = ParamSpec("_P")
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_prop_rule(
op: Union[torch._ops.OpOverload, List[torch._ops.OpOverload]],
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding]
@ -101,9 +102,9 @@ def register_op_strategy(
def as_list(
x: Union[List[object], object]
x: Union[list[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
# which is an object but treated as a list by the tracer. Therefore, keep
# `immutable_list` intact here as well.
@ -178,7 +179,7 @@ def is_tensor_partial(spec: DTensorSpec) -> bool:
def infer_broadcast_dims_map(
common_shape: torch.Size, input_shape: torch.Size
) -> List[int]:
) -> list[int]:
# infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
# this is aligned with the broadcast semantics
common_ndim = len(common_shape)
@ -193,10 +194,10 @@ def infer_broadcast_dims_map(
def map_placements_after_broadcast(
placements: tuple[Placement, ...],
shape: torch.Size,
broadcast_dims_map: List[int],
broadcast_dims_map: list[int],
) -> tuple[Placement, ...]:
"""Map each placement based on the output shape after broadcast."""
new_placements: List[Placement] = []
new_placements: list[Placement] = []
for placement in placements:
if isinstance(placement, (Replicate, Partial)):
new_placements.append(placement)
@ -223,8 +224,8 @@ def map_placements_after_broadcast(
def generate_redistribute_costs(
src_strategy: OpStrategy, dst_spec: DTensorSpec
) -> List[float]:
redistribute_costs: List[float] = [
) -> list[float]:
redistribute_costs: list[float] = [
redistribute_cost(strat.output_spec, dst_spec)
for strat in src_strategy.strategies
]
@ -235,7 +236,7 @@ def generate_redistribute_costs(
def expand_to_full_mesh_op_strategy(
mesh: DeviceMesh,
op_schema: OpSchema,
single_mesh_dim_strategies: List[PlacementList],
single_mesh_dim_strategies: list[PlacementList],
*,
input_index: int = 1,
inplace_op: bool = False,
@ -247,14 +248,14 @@ def expand_to_full_mesh_op_strategy(
all_strategies = []
for strategy_comb in strategy_combs:
spec_list: List[Optional[DTensorSpec]] = []
spec_list: list[Optional[DTensorSpec]] = []
for specs in zip(*strategy_comb):
if specs[0] is not None:
spec_list.append(DTensorSpec(mesh, specs))
else:
spec_list.append(None)
input_specs: List[DTensorSpec] = [
input_specs: list[DTensorSpec] = [
s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
]

View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
import warnings
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
import torch.distributed as dist
@ -110,12 +110,12 @@ class _RNGStateTracker:
f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
)
self._states: Dict[str, Tensor] = {}
self._states: dict[str, Tensor] = {}
self._devices = [self._device_handle.current_device()]
self._use_distribute_region = True
@property
def rng_states(self) -> Dict[str, Tensor]:
def rng_states(self) -> dict[str, Tensor]:
return self._states
@property
@ -267,7 +267,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
mesh = spec.mesh
# note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
# case. Replace the custom logic with dim_map once we support it.
dim_map: List[Union[int, List[int]]] = [-1] * spec.ndim
dim_map: list[Union[int, list[int]]] = [-1] * spec.ndim
for i, placement in enumerate(spec.placements):
if isinstance(placement, Shard):
shard_dim = placement.dim
@ -275,7 +275,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
dim_map[shard_dim] = [i]
else:
mesh_dim_list = dim_map[shard_dim]
assert isinstance(mesh_dim_list, List)
assert isinstance(mesh_dim_list, list)
mesh_dim_list.append(i)
# Compute shard coordinate:
@ -291,7 +291,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
shard_idx = 0
total_num_shards = 1
# the tensor dim is sharded on more than 1 mesh dim
if isinstance(mesh_dim, List):
if isinstance(mesh_dim, list):
rank_coord = [mesh_coordinate[d] for d in mesh_dim]
num_shards = [mesh_size[d] for d in mesh_dim]
# compute the shard idx and total number of shards
@ -356,7 +356,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
self.set_offset("parallel-rng", old_offset + numel)
def _calc_shard_linear_idx(
self, shard_coord: List[int], shard_size: List[int]
self, shard_coord: list[int], shard_size: list[int]
) -> int:
# compute shard linear index
shard_linear_idx = 0

View File

@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from functools import lru_cache
from typing import cast, List, NamedTuple
from functools import cache
from typing import cast, NamedTuple
import torch
import torch.distributed._functional_collectives as funcol
@ -24,13 +24,13 @@ class _TransformInfo(NamedTuple):
mesh_dim: int
src_dst_placements: tuple[Placement, Placement]
# logical_shape on this mesh dimension
logical_shape: List[int]
logical_shape: list[int]
def _gen_transform_infos_non_cached(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
) -> list[_TransformInfo]:
"""
Generate the transform infos from the source placements to the target placements.
@ -42,7 +42,7 @@ def _gen_transform_infos_non_cached(
the former is a nested sharding of a tensor already sharded on dimension 0, whereas
the latter is the first sharding on tensor dimension 0.
"""
transform_infos: List[_TransformInfo] = []
transform_infos: list[_TransformInfo] = []
device_mesh = src_spec.device_mesh
my_coordinate = device_mesh.get_coordinate()
@ -145,11 +145,11 @@ def _gen_transform_infos_non_cached(
return transform_infos
@lru_cache(maxsize=None)
@cache
def _gen_transform_infos(
src_spec: DTensorSpec,
dst_spec: DTensorSpec,
) -> List[_TransformInfo]:
) -> list[_TransformInfo]:
return _gen_transform_infos_non_cached(src_spec, dst_spec)
@ -332,7 +332,7 @@ class Redistribute(torch.autograd.Function):
is_backward=True,
)
# normalize the target placement to replicate if it is partial
normalized_placements: List[Placement] = []
normalized_placements: list[Placement] = []
for previous_placement in previous_spec.placements:
if previous_placement.is_partial():
# keep target placement to replicate instead of partial in this case

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import threading
from collections.abc import Sequence
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Dict, List, Optional, Sequence, Union
from typing import Callable, cast, Optional, Union
import torch
from torch._ops import OpOverload
@ -51,18 +52,18 @@ class LocalLRUCache(threading.local):
class ShardingPropagator:
def __init__(self) -> None:
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: Dict[
self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: dict[
OpOverload,
Callable[[DeviceMesh, OpSchema], StrategyType],
] = {}
# op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {}
self.propagate_op_sharding = LocalLRUCache(
self.propagate_op_sharding_non_cached
)
# op map to save indices of shape (and stride) args which may need to be modified in sharding prop
self.op_to_shape_and_stride_idx: Dict[
self.op_to_shape_and_stride_idx: dict[
OpOverload, Union[int, tuple[int, int]]
] = {
# new factory ops
@ -128,7 +129,7 @@ class ShardingPropagator:
)
elif isinstance(fake_out, (tuple, list)):
tensor_meta_list: List[Optional[TensorMeta]] = []
tensor_meta_list: list[Optional[TensorMeta]] = []
for fake_out_item in fake_out:
if isinstance(fake_out_item, torch.Tensor):
tensor_meta_list.append(
@ -260,7 +261,7 @@ class ShardingPropagator:
# check if we need to redistribute the input
needs_redistribute = False
expected_input_specs: List[DTensorSpec] = []
expected_input_specs: list[DTensorSpec] = []
# in case where the op does not specify input_specs and output_specs
# is a DTensorSpec, we use output_specs as the spec for each DTensor
@ -335,8 +336,8 @@ class ShardingPropagator:
elif isinstance(op_strategy, TupleStrategy):
# tuple strategy output sharding processing
# runtime selected placement strategy for each TupleStrategy input arg
selected_strategies: List[PlacementStrategy] = []
out_spec_list: List[DTensorSpec] = []
selected_strategies: list[PlacementStrategy] = []
out_spec_list: list[DTensorSpec] = []
for strategy in op_strategy.childs:
assert isinstance(strategy, OpStrategy)
selected_strategy = self._select_strategy(strategy)
@ -344,7 +345,7 @@ class ShardingPropagator:
out_spec_list.append(selected_strategy.output_spec)
needs_redistribute = False
suggestion_args: List[object] = []
suggestion_args: list[object] = []
tensor_or_list_tensor_arg_idx = 0
for arg in op_schema.args_schema:
@ -353,7 +354,7 @@ class ShardingPropagator:
and isinstance(arg, (list, tuple))
and isinstance(arg[0], DTensorSpec)
):
expected_input_spec_list: List[DTensorSpec] = []
expected_input_spec_list: list[DTensorSpec] = []
for idx, arg_spec in enumerate(arg):
expected_input_spec = selected_strategies[idx].input_spec(
tensor_or_list_tensor_arg_idx
@ -461,7 +462,7 @@ class ShardingPropagator:
# short cut with only one possible strategy
return strategy.strategies[0]
strategy_costs: List[float] = []
strategy_costs: list[float] = []
for strtg in strategy.strategies:
assert (
strtg.redistribute_cost is not None

View File

@ -5,7 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List
from typing import Any
import torch
from torch.distributed.checkpoint.metadata import (
@ -34,12 +34,12 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
"""
__slots__ = ["_local_shards", "_storage_meta"]
_local_shards: List[torch.Tensor]
_local_shards: list[torch.Tensor]
_storage_meta: TensorStorageMetadata
@staticmethod
def __new__(
cls, local_shards: List[torch.Tensor], local_offsets: List[tuple[int, ...]]
cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]]
) -> "LocalShardsWrapper":
assert len(local_shards) > 0
assert len(local_shards) == len(local_offsets)
@ -206,7 +206,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
[shard.requires_grad_(requires_grad) for shard in self._local_shards]
return self
def local_shards(self) -> List[torch.Tensor]:
def local_shards(self) -> list[torch.Tensor]:
"""
Returns a list of :class:`torch.Tensor` corresponding to the
local shards for this rank. Returns an empty list if the current rank
@ -214,7 +214,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
"""
return self._local_shards
def local_sizes(self) -> List[torch.Size]:
def local_sizes(self) -> list[torch.Size]:
"""
Returns a list of :class:`torch.Size` corresponding to the
local sizes for the shards on this rank. Returns an empty list if the current rank
@ -222,7 +222,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
"""
return [chunk.sizes for chunk in self._storage_meta.chunks]
def local_offsets(self) -> List[torch.Size]:
def local_offsets(self) -> list[torch.Size]:
"""
Returns a list of :class:`torch.Size` corresponding to the
local offsets for the shards on this rank. Returns an empty list if the current rank
@ -231,7 +231,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
return [chunk.offsets for chunk in self._storage_meta.chunks]
@property
def local_chunks(self) -> List[ChunkStorageMetadata]:
def local_chunks(self) -> list[ChunkStorageMetadata]:
"""
Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
metadata for each tensor shard
@ -247,7 +247,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
def __create_write_items__(
self, fqn: str, object: Any
) -> List[WriteItem]: # pyre-ignore[2]
) -> list[WriteItem]: # pyre-ignore[2]
"""
For compatibility with DCP, we support creation of WriteItems
such that they can be saved properly.
@ -268,7 +268,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
for tensor, chunks in zip(self.local_shards(), self.local_chunks)
]
def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
def __create_chunk_list__(self) -> list[ChunkStorageMetadata]:
"""
For compatibility with DCP, we support creation of chunk lists
such that they can be saved properly.

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
from typing import cast, Dict, List
from typing import cast
import torch
import torch.distributed as dist
@ -106,7 +106,7 @@ def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
def tp_convolution(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: Dict[str, object],
local_tensor_kwargs: dict[str, object],
) -> object:
assert op_call == aten.convolution.default
assert len(local_tensor_args) == 9
@ -118,7 +118,7 @@ def tp_convolution(
stride, padding, dilation = local_tensor_args[3:6]
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, List)
assert isinstance(padding, list)
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
@ -159,7 +159,7 @@ def tp_convolution(
def tp_convolution_backward(
op_call: torch._ops.OpOverload,
local_tensor_args: tuple[object, ...],
local_tensor_kwargs: Dict[str, object],
local_tensor_kwargs: dict[str, object],
) -> object:
assert op_call == aten.convolution_backward.default
assert len(local_tensor_args) == 11
@ -172,7 +172,7 @@ def tp_convolution_backward(
stride, padding, dilation = local_tensor_args[4:7]
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
assert isinstance(padding, List)
assert isinstance(padding, list)
if not _requires_data_exchange(padding):
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
@ -229,7 +229,7 @@ def tp_convolution_backward(
def convolution_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
# extract local tensor and sharding infos to a OpInfo
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
@ -252,7 +252,7 @@ def convolution_handler(
def convolution_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as input tensor
args = list(args)

View File

@ -1,4 +1,5 @@
from typing import cast, List, Sequence
from collections.abc import Sequence
from typing import cast
import torch
import torch.distributed.tensor._api as dtensor
@ -165,7 +166,7 @@ def compute_local_shape_and_global_offset(
def compute_global_tensor_info(
tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
) -> tuple[List[int], List[int]]:
) -> tuple[list[int], list[int]]:
"""
Compute the global size and stride of a DTensor from the given local tensor.
The local size is multiplied by `world_size` per Sharding dim.

View File

@ -4,7 +4,7 @@ import json
import re
import weakref
from collections import defaultdict
from typing import Any, Dict
from typing import Any
import torch
import torch.nn
@ -240,7 +240,7 @@ class CommDebugMode(TorchDispatchMode):
"""
def __init__(self):
self.comm_counts: Dict[Any, int] = defaultdict(int)
self.comm_counts: dict[Any, int] = defaultdict(int)
self.comm_module_counts = {}
self.comm_module_operation_counts = {}
self.comm_registry = set()
@ -392,7 +392,7 @@ class CommDebugMode(TorchDispatchMode):
return json_dict
json_dict: Dict[str, Any] = {}
json_dict: dict[str, Any] = {}
add_json_information(json_dict, "Global")
# converts dictionary into json file
@ -567,7 +567,7 @@ class CommDebugMode(TorchDispatchMode):
def get_total_counts(self) -> int:
return sum(self.comm_counts.values())
def get_comm_counts(self) -> Dict[Any, int]:
def get_comm_counts(self) -> dict[Any, int]:
"""Returns the communication counts as a dictionary.
Returns:
@ -575,10 +575,10 @@ class CommDebugMode(TorchDispatchMode):
"""
return self.comm_counts
def get_parameter_info(self) -> Dict[str, Dict[str, Any]]:
def get_parameter_info(self) -> dict[str, dict[str, Any]]:
return self.advanced_module_tracker.module_parameters_dict
def get_sharding_info(self) -> Dict[str, Dict[str, Any]]:
def get_sharding_info(self) -> dict[str, dict[str, Any]]:
return self.advanced_module_tracker.sharding_dict
def __enter__(self):

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from operator import itemgetter
from typing import List
import torch
import torch.fx
@ -13,7 +12,7 @@ from torch.distributed.tensor import DTensor
inductor_decomps = select_decomp_table()
graphs: List[torch.fx.GraphModule] = []
graphs: list[torch.fx.GraphModule] = []
def fwd_bwd_compiler(fx_g, _):

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Sequence
from collections.abc import Sequence
import numpy as np
@ -89,7 +89,7 @@ def _compute_local_shape_and_global_offset(
global_shape: ShapeType,
mesh: DeviceMesh,
placements: Sequence[Placement],
my_coordinate: List[int],
my_coordinate: list[int],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""
Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset but

View File

@ -4,7 +4,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.p
"""
import argparse
import os
from typing import Callable, Dict, Union
from typing import Callable, Union
import torch
import torch.nn as nn
@ -713,7 +713,7 @@ def run_example(world_size: int, rank: int, example_name: str) -> None:
# initializing class with all of the functions
instantiated_example = CommDebugModeExample(world_size, rank)
# dict that stores example code function names
name_to_example_code: Dict[str, Callable[[], None]] = {
name_to_example_code: dict[str, Callable[[], None]] = {
"MLP_distributed_sharding_display": instantiated_example.example_MLP_distributed_sharding_display,
"MLPStacked_distributed_sharding_display": instantiated_example.example_MLPStacked_distributed_sharding_display,
"MLP_module_tracing": instantiated_example.example_MLP_module_tracing,

View File

@ -6,7 +6,7 @@ sharding with the DTensor API.
import argparse
import os
from functools import cached_property
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING
import torch
from torch.distributed.checkpoint.metadata import (
@ -43,12 +43,12 @@ supported_ops = [aten.view.default, aten._to_copy.default]
# this torch.Tensor subclass is a wrapper around all local shards associated
# with a single sharded embedding table.
class LocalShardsWrapper(torch.Tensor):
local_shards: List[torch.Tensor]
local_shards: list[torch.Tensor]
storage_meta: TensorStorageMetadata
@staticmethod
def __new__(
cls, local_shards: List[torch.Tensor], offsets: List[torch.Size]
cls, local_shards: list[torch.Tensor], offsets: list[torch.Size]
) -> "LocalShardsWrapper":
assert len(local_shards) > 0
assert len(local_shards) == len(offsets)
@ -98,19 +98,19 @@ class LocalShardsWrapper(torch.Tensor):
)
@property
def shards(self) -> List[torch.Tensor]:
def shards(self) -> list[torch.Tensor]:
return self.local_shards
@shards.setter
def shards(self, local_shards: List[torch.Tensor]):
def shards(self, local_shards: list[torch.Tensor]):
self.local_shards = local_shards
@cached_property
def shard_sizes(self) -> List[torch.Size]:
def shard_sizes(self) -> list[torch.Size]:
return [chunk.sizes for chunk in self.storage_meta.chunks]
@cached_property
def shard_offsets(self) -> List[torch.Size]:
def shard_offsets(self) -> list[torch.Size]:
return [chunk.offsets for chunk in self.storage_meta.chunks]
@ -158,7 +158,7 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
# this is the sharding placement we use in DTensor to represent row-wise sharding
# row_wise_sharding_placements means that the global tensor is sharded by first dim
# over the 1-d mesh.
row_wise_sharding_placements: List[Placement] = [Shard(0)]
row_wise_sharding_placements: list[Placement] = [Shard(0)]
# create a DTensor from the local shard
dtensor = DTensor.from_local(
@ -226,7 +226,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
###########################################################################
# example 1: transform local_shards into DTensor
# create the DTensorMetadata which torchrec should provide
row_wise_sharding_placements: List[Placement] = [Shard(0)]
row_wise_sharding_placements: list[Placement] = [Shard(0)]
# note: for uneven sharding, we need to specify the shape and stride because
# DTensor would assume even sharding and compute shape/stride based on the

View File

@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Iterator
from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor.experimental._attention import context_parallel

View File

@ -6,9 +6,10 @@ import logging
import types
import weakref
from abc import ABC, abstractmethod
from collections.abc import Generator
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, Callable, Dict, Generator, List, Optional, Protocol, Set, Union
from typing import Any, Callable, Optional, Protocol, Union
import torch
import torch.distributed as dist
@ -435,7 +436,7 @@ def _templated_ring_attention(
raise RuntimeError("Load balancing requires `is_causal=True`.")
if isinstance(mesh, dist.ProcessGroup):
pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
pg: Union[dist.ProcessGroup, list[dist.ProcessGroup]] = mesh
else:
pg = mesh.get_group()
assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
@ -452,7 +453,7 @@ def _templated_ring_attention(
sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim)
rest: List[Any]
rest: list[Any]
out: torch.Tensor
logsumexp: torch.Tensor
@ -519,7 +520,7 @@ def _templated_ring_attention(
def _sdpa_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
# extract local tensor and sharding infos to a OpInfo
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
@ -557,7 +558,7 @@ def _sdpa_handler(
def _sdpa_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as output tensor
args = list(args)
@ -614,7 +615,7 @@ def _templated_ring_attention_backward(
size = dist.get_world_size(pg)
next_kv = None
next_grad_kv = None
rest: List[Any]
rest: list[Any]
grad_query_, grad_key_, grad_value_ = None, None, None
accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype
@ -850,7 +851,7 @@ customized_ops = {
}
_replaced_functions: Dict[Callable, tuple[str, Callable]] = {}
_replaced_functions: dict[Callable, tuple[str, Callable]] = {}
def _distribute_function(
@ -890,7 +891,7 @@ def _distribute_function(
def wrapper(
target_fn: Callable, input_fn: Optional[Callable], output_fn: Optional[Callable]
) -> Callable:
def inner_fn(*args: tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
def inner_fn(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
if input_fn is not None:
args, kwargs = input_fn(device_mesh, *args, **kwargs)
output = target_fn(*args, **kwargs)
@ -1045,8 +1046,8 @@ def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, N
"""Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher."""
def attention_input_fn(
mesh: DeviceMesh, *args: tuple[Any, ...], **kwargs: Dict[str, Any]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
mesh: DeviceMesh, *args: tuple[Any, ...], **kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
placement = [Shard(seq_dim)]
all_args = []
@ -1180,9 +1181,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):
def _context_parallel_buffers(
mesh: DeviceMesh,
buffers: List[torch.Tensor],
buffer_seq_dims: List[int],
) -> List[torch.Tensor]:
buffers: list[torch.Tensor],
buffer_seq_dims: list[int],
) -> list[torch.Tensor]:
"""Shard the buffers along the sequence dimensions according to CP rules."""
new_buffers = []
sharder = (
@ -1201,9 +1202,9 @@ def _context_parallel_buffers(
def context_parallel(
mesh: DeviceMesh,
*,
buffers: Optional[List[torch.Tensor]] = None,
buffer_seq_dims: Optional[List[int]] = None,
no_restore_buffers: Optional[Set[torch.Tensor]] = None,
buffers: Optional[list[torch.Tensor]] = None,
buffer_seq_dims: Optional[list[int]] = None,
no_restore_buffers: Optional[set[torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""
@ -1267,9 +1268,9 @@ def context_parallel(
@torch.no_grad()
def context_parallel_unshard(
mesh: DeviceMesh,
buffers: List[torch.Tensor],
seq_dims: List[int],
) -> List[torch.Tensor]:
buffers: list[torch.Tensor],
seq_dims: list[int],
) -> list[torch.Tensor]:
"""
Unshard the tensors (e.g., output) that are sharded due to context parallelism.

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Callable, Optional, Union
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence
from functools import partial
from typing import Callable, List, Sequence, Union
from typing import Callable, Union
import torch
from torch._ops import OpOverload
@ -21,7 +22,7 @@ from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
__all__ = ["register_sharding"]
def register_sharding(op: Union[OpOverload, List[OpOverload]]):
def register_sharding(op: Union[OpOverload, list[OpOverload]]):
"""
:meth:`register_sharding` is an experimental API that allows users to register sharding
strategies for an operator when the tensor inputs and outputs are DTensor.
@ -89,7 +90,7 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]):
acceptable_shardings = custom_sharding_fn(*args_schema, **kwargs_schema)
single_mesh_dim_strategies: List[PlacementList] = []
single_mesh_dim_strategies: list[PlacementList] = []
for output_specs, input_specs in acceptable_shardings:
single_mesh_dim_strategies.append(output_specs + input_specs)
@ -110,7 +111,7 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]):
# 2. let static_kwargkey include all the int type kwargs
# 3. always set needs_pytree=True
static_argnum = 100
static_kwargkey: List[str] = []
static_kwargkey: list[str] = []
for i, arg in enumerate(op._schema.arguments):
if isinstance(arg.type, torch.IntType) or (
isinstance(arg.type, torch.OptionalType)

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import copy
import operator
from typing import Any, cast, Dict, List, Optional, Sequence
from collections.abc import Sequence
from typing import Any, cast, Optional
import torch
from torch._subclasses.fake_tensor import FakeTensor
@ -36,7 +37,7 @@ def tensor_parallel_transformation(
rank: int,
world_size: int,
device_type: str,
parallel_strategies: Dict[str, ParallelStyle],
parallel_strategies: dict[str, ParallelStyle],
) -> ExportedProgram:
"""
The entry point function to perform graph transformations on an exported program
@ -77,14 +78,14 @@ class _TensorParallelTransformPass(PassBase):
rank: int,
world_size: int,
device_type: str,
state_dict: Dict[str, torch.Tensor],
state_dict: dict[str, torch.Tensor],
graph_signature: ExportGraphSignature,
parallel_strategies: Dict[str, ParallelStyle],
parallel_strategies: dict[str, ParallelStyle],
) -> None:
super().__init__()
self.rank = rank
self.mesh = DeviceMesh(device_type, torch.arange(world_size))
self.state_dict: Dict[str, torch.Tensor] = state_dict
self.state_dict: dict[str, torch.Tensor] = state_dict
self.graph_signature = graph_signature
self.parallel_strategies = parallel_strategies
@ -105,13 +106,13 @@ class _TensorParallelTransformPass(PassBase):
def _generate_parameter_and_buffer_placements(
params_and_buffers: List[str],
parallel_strategies: Dict[str, ParallelStyle],
) -> Dict[str, Placement]:
params_and_buffers: list[str],
parallel_strategies: dict[str, ParallelStyle],
) -> dict[str, Placement]:
"""
Build parameter placements based on the give parallel style of linear layers.
"""
parameter_placements: Dict[str, Placement] = {}
parameter_placements: dict[str, Placement] = {}
for linear_fqn, parallel_style in parallel_strategies.items():
weight_fqn = f"{linear_fqn}.weight"
bias_fqn = f"{linear_fqn}.bias"
@ -130,12 +131,12 @@ def _mark_tensor_parallel_shardings(
gm: GraphModule,
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
parameter_placements: Dict[str, Placement],
) -> Dict[Node, PlacementStrategy]:
parameter_placements: dict[str, Placement],
) -> dict[Node, PlacementStrategy]:
"""
Mark the placement strategies of the parameter and buffer placeholder nodes.
"""
placement_strategies: Dict[Node, PlacementStrategy] = {}
placement_strategies: dict[Node, PlacementStrategy] = {}
num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len(
graph_signature.inputs_to_buffers
)
@ -182,12 +183,12 @@ def _mark_sharding(
gm: GraphModule,
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
parameter_placements: Dict[str, Placement],
) -> Dict[Node, PlacementStrategy]:
parameter_placements: dict[str, Placement],
) -> dict[Node, PlacementStrategy]:
"""
Mark the sharding strategy for each node in the graph module.
"""
placement_strategies: Dict[
placement_strategies: dict[
Node, PlacementStrategy
] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)
@ -485,12 +486,12 @@ def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:
def _get_input_node_specs(
node: Node, placement_strategies: Dict[Node, PlacementStrategy]
node: Node, placement_strategies: dict[Node, PlacementStrategy]
) -> tuple[DTensorSpec, ...]:
"""
Get the input specs of a node.
"""
input_specs_list: List[DTensorSpec] = []
input_specs_list: list[DTensorSpec] = []
for input_arg in node.all_input_nodes:
if input_arg in placement_strategies:
output_spec = placement_strategies[input_arg].output_specs
@ -502,7 +503,7 @@ def _get_input_node_specs(
def _get_op_schema(
node: Node, placement_strategies: Dict[Node, PlacementStrategy]
node: Node, placement_strategies: dict[Node, PlacementStrategy]
) -> OpSchema:
"""
Util function to construct the operator schema of a node.
@ -513,14 +514,14 @@ def _get_op_schema(
op_schema = OpSchema(
op=cast(torch._ops.OpOverload, node.target),
args_schema=tuple(args_schema_list),
kwargs_schema=cast(Dict[str, object], node.kwargs),
kwargs_schema=cast(dict[str, object], node.kwargs),
)
return op_schema
def _shard_state_dict(
state_dict: Dict[str, torch.Tensor],
placement_strategies: Dict[Node, PlacementStrategy],
state_dict: dict[str, torch.Tensor],
placement_strategies: dict[Node, PlacementStrategy],
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
) -> None:
View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Dict, Optional, Union
from typing import Optional, Union
import torch
import torch.nn as nn
@ -16,7 +16,7 @@ __all__ = ["parallelize_module"]
def parallelize_module( # type: ignore[return]
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
*,
src_data_rank: Optional[int] = 0,
) -> nn.Module:
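A small usage sketch (not part of this diff): the plan is either a single ParallelStyle or, as here, a dict[str, ParallelStyle] keyed by submodule FQN; the FeedForward module, its dimensions, and the 8-rank mesh are illustrative assumptions.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden)
        self.w2 = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

tp_mesh = init_device_mesh("cuda", (8,))  # assumes an 8-rank job

# Colwise-shard the first projection and rowwise-shard the second; the fnmatch
# import in this module suggests wildcard FQN patterns are accepted as keys too.
sharded = parallelize_module(
    FeedForward(),
    tp_mesh,
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)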
View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, List, Optional, Set
from typing import Any, Optional
import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (
@ -23,7 +23,7 @@ def _get_submodule_n_params(module: nn.Module, path: str):
return module, path
def _update_module_param(param_list: List[tuple[nn.Module, str, nn.Parameter]]):
def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]):
"""
Update parameters within the module
"""
@ -48,7 +48,7 @@ def _reconstruct_dtensor(module: nn.Module, _input: Any):
def _localize_dtensor(
module: nn.Module, *_: Any, ignored_params: Optional[Set[nn.Parameter]] = None
module: nn.Module, *_: Any, ignored_params: Optional[set[nn.Parameter]] = None
):
"""
Convert DTensor parameters to local tensors
View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import copy
from typing import Any, cast, List, Optional
from typing import Any, cast, Optional
import torch
import torch.distributed as dist
@ -163,7 +163,7 @@ def _chunk_tensor(
)
outer_local_shard = tensor.local_shards()[0]
shards: List[Shard] = [
shards: list[Shard] = [
Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
]
st_meta = copy.deepcopy(tensor.metadata())
@ -284,7 +284,7 @@ def _chunk_dtensor(
def _pre_load_state_dict(
tensor: torch.Tensor,
) -> tuple[torch.Tensor, List[Shard]]:
) -> tuple[torch.Tensor, list[Shard]]:
shards = cast(ShardedTensor, tensor).local_shards()
if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
inner_tensor = shards[0].tensor
@ -377,7 +377,7 @@ class DTensorExtensions(FSDPExtensions):
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> tuple[torch.Tensor, List[Shard]]:
) -> tuple[torch.Tensor, list[Shard]]:
return _pre_load_state_dict(tensor)
def all_gather_dtensor(
View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Dict, Optional
from typing import cast, Optional
import torch
import torch._prims_common as utils
@ -108,7 +108,7 @@ def _cast_to_dtensor(
def _propagate_tensor_meta(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> TensorMeta:
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
@ -154,7 +154,7 @@ def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
def _log_softmax_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
x = cast(DTensor, args[0])
dim = cast(int, args[1])
@ -185,7 +185,7 @@ def _log_softmax_handler(
def _log_softmax_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
grad_output = cast(DTensor, args[0])
input_dtype = cast(torch.dtype, args[3])
@ -270,7 +270,7 @@ def _nll_loss_forward(
def _nll_loss_forward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
x = cast(DTensor, args[0])
target = args[1]
@ -414,7 +414,7 @@ def _nll_loss_and_log_softmax_backward(
def _nll_loss_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
grad_output = cast(DTensor, args[0])
x = cast(DTensor, args[1])
View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
import torch
import torch.nn as nn
@ -453,8 +453,8 @@ class PrepareModuleInput(ParallelStyle):
desired_input_layouts: Optional[
Union[Placement, tuple[Optional[Placement]]]
] = None,
input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
input_kwarg_layouts: Optional[dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
use_local_output: bool = False,
):
self.input_layouts = (
View File

@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from dataclasses import dataclass
from typing import cast, List, Optional
from typing import cast, Optional
import torch
import torch.distributed._functional_collectives as funcol
@ -73,7 +73,7 @@ class Shard(Placement):
*,
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[List[torch.Tensor], List[int]]:
) -> tuple[list[torch.Tensor], list[int]]:
"""
This function uses torch.chunk to split a tensor into num_chunks shards along
the Shard placement dimension, and return a list of shards with their pad sizes.
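To make the chunk-and-pad contract concrete, here is a standalone sketch of the idea (illustrative only; the real method also honors with_padding/contiguous and handles torch.chunk returning fewer chunks than ranks). Note the return annotation uses the built-in tuple/list generics, matching the change above.

import torch
import torch.nn.functional as F

def chunk_with_pad(
    tensor: torch.Tensor, dim: int, num_chunks: int
) -> tuple[list[torch.Tensor], list[int]]:
    # torch.chunk splits along `dim`; trailing chunks may come out smaller.
    chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
    full = chunks[0].size(dim)
    pad_sizes = [full - c.size(dim) for c in chunks]
    # Right-pad each short chunk on `dim` so every rank holds an equally sized shard.
    padded = []
    for chunk, pad in zip(chunks, pad_sizes):
        if pad:
            spec = [0, 0] * (chunk.ndim - dim)
            spec[-1] = pad  # F.pad orders the pad spec from the last dim backwards
            chunk = F.pad(chunk, spec)
        padded.append(chunk)
    return padded, pad_sizes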
@ -236,7 +236,7 @@ class Shard(Placement):
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
current_logical_shape: list[int],
) -> torch.Tensor:
"""
This function all_gather all shards and return a tensor that
@ -292,7 +292,7 @@ class Shard(Placement):
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
current_logical_shape: list[int],
new_shard_dim: int,
) -> torch.Tensor:
"""
@ -464,7 +464,7 @@ class _StridedShard(Shard):
*,
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[List[torch.Tensor], List[int]]:
) -> tuple[list[torch.Tensor], list[int]]:
"""
TODO: currently _StridedShard does not support padding
"""
@ -502,7 +502,7 @@ class _StridedShard(Shard):
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
current_logical_shape: list[int],
) -> torch.Tensor:
"""
Note: currently _StridedShard does not support padding