Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
PEP585 update - torch/distributed/tensor (#145141)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145141
Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent 4f8237dbad
commit c95efc37ba
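The change is mechanical: container annotations imported from typing (List, Dict, Set) become the builtin generics allowed by PEP 585, abstract types such as Sequence and Iterable move to collections.abc, and one use of lru_cache(maxsize=None) becomes functools.cache. A minimal sketch of the pattern applied throughout the diff below; the function names here are illustrative and do not appear in the PR:

# Before: aliases from typing, deprecated for this use since Python 3.9 (PEP 585)
# from functools import lru_cache
# from typing import Dict, List, Optional, Sequence

# After: builtin generics, collections.abc, and functools.cache
from collections.abc import Sequence
from functools import cache
from typing import Optional


def first_shard(shards: list[int], default: Optional[int] = None) -> Optional[int]:
    # list[int] replaces typing.List[int]
    return shards[0] if shards else default


def sizes_by_name(chunks: Sequence[str]) -> dict[str, int]:
    # Sequence now comes from collections.abc; dict[...] replaces typing.Dict[...]
    return {chunk: len(chunk) for chunk in chunks}


@cache  # equivalent to lru_cache(maxsize=None)
def factorial(n: int) -> int:
    return n * factorial(n - 1) if n else 1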
@@ -3,7 +3,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import inspect
 import warnings
-from typing import Any, Callable, cast, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, cast, Optional
 from typing_extensions import deprecated
 
 import torch
@@ -3,7 +3,7 @@ import logging
 import math
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.distributed._functional_collectives as funcol
@@ -73,7 +73,7 @@ def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim):
 
 def mesh_scatter(
     output: torch.Tensor,
-    scatter_list: List[torch.Tensor],
+    scatter_list: list[torch.Tensor],
     mesh: DeviceMesh,
     mesh_dim: int = 0,
     async_op: bool = False,
@@ -195,8 +195,8 @@ def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Ten
 
 
 def fill_empty_tensor_to_shards(
-    shards: List[torch.Tensor], shard_dim: int, num_empty_tensors: int
-) -> List[torch.Tensor]:
+    shards: list[torch.Tensor], shard_dim: int, num_empty_tensors: int
+) -> list[torch.Tensor]:
     if num_empty_tensors == 0:
         return shards
     tensor_size = list(shards[0].size())
@@ -244,9 +244,9 @@ class MeshTopoInfo:
     """
 
     mesh: DeviceMesh
-    mesh_dim_devices: List[int]
-    mesh_dim_bandwidth: List[float]
-    mesh_dim_latency: List[float]
+    mesh_dim_devices: list[int]
+    mesh_dim_bandwidth: list[float]
+    mesh_dim_latency: list[float]
 
     @staticmethod
     @lru_cache(None)
@ -4,7 +4,8 @@ import functools
|
|||
import logging
|
||||
import operator
|
||||
import warnings
|
||||
from typing import cast, Dict, List, Optional, Sequence, TYPE_CHECKING
|
||||
from collections.abc import Sequence
|
||||
from typing import cast, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -44,7 +45,7 @@ logger = logging.getLogger(__name__)
|
|||
def decompose_handler(
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
"""
|
||||
Decomposes a op to core ATen op, this handler is mostly here
|
||||
|
|
@ -60,7 +61,7 @@ def decompose_handler(
|
|||
def is_same_size_handler(
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> bool:
|
||||
lhs = cast(torch.Tensor, args[0])
|
||||
rhs = cast(torch.Tensor, args[1])
|
||||
|
|
@ -70,11 +71,11 @@ def is_same_size_handler(
|
|||
def found_inf_reduce_handler(
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> None:
|
||||
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
|
||||
local_tensor_args = pytree.tree_unflatten(
|
||||
cast(List[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type]
|
||||
cast(list[object], op_info.local_args), op_info.args_tree_spec # type: ignore[arg-type]
|
||||
)
|
||||
local_tensor_args = cast(tuple[object, ...], local_tensor_args)
|
||||
op_call(*local_tensor_args, **op_info.local_kwargs)
|
||||
|
|
@ -154,7 +155,7 @@ class OpDispatcher:
|
|||
self,
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
"""
|
||||
Main dispatching logic
|
||||
|
|
@ -185,7 +186,7 @@ class OpDispatcher:
|
|||
|
||||
local_tensor_args = (
|
||||
pytree.tree_unflatten(
|
||||
cast(List[object], op_info.local_args), op_info.args_tree_spec
|
||||
cast(list[object], op_info.local_args), op_info.args_tree_spec
|
||||
)
|
||||
if op_info.args_tree_spec
|
||||
else op_info.local_args
|
||||
|
|
@ -251,7 +252,7 @@ class OpDispatcher:
|
|||
local_results = [
|
||||
default_tensor(s) if s is not None else None for s in spec
|
||||
]
|
||||
assert isinstance(local_results, List)
|
||||
assert isinstance(local_results, list)
|
||||
if None in local_results:
|
||||
ret_type = str(ret_list[0].type)
|
||||
raise NotImplementedError(
|
||||
|
|
@ -309,7 +310,7 @@ class OpDispatcher:
|
|||
else:
|
||||
flatten_args_schema_to_reshard = suggested_input_schema.args_schema
|
||||
|
||||
new_local_args: List[object] = []
|
||||
new_local_args: list[object] = []
|
||||
for i, arg_spec in enumerate(op_info.flat_args_schema):
|
||||
reshard_arg_spec = flatten_args_schema_to_reshard[i]
|
||||
if isinstance(arg_spec, DTensorSpec):
|
||||
|
|
@ -330,7 +331,7 @@ class OpDispatcher:
|
|||
self,
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> OpInfo:
|
||||
# get runtime schema info to determine whether to use pytree to flatten inputs
|
||||
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
|
||||
|
|
@ -344,10 +345,10 @@ class OpDispatcher:
|
|||
else:
|
||||
args_list, args_spec = args, None
|
||||
|
||||
args_schema: List[object] = []
|
||||
kwargs_schema: Dict[str, object] = {}
|
||||
local_args: List[object] = []
|
||||
local_kwargs: Dict[str, object] = {}
|
||||
args_schema: list[object] = []
|
||||
kwargs_schema: dict[str, object] = {}
|
||||
local_args: list[object] = []
|
||||
local_kwargs: dict[str, object] = {}
|
||||
mesh: Optional[DeviceMesh] = None
|
||||
|
||||
for arg in args_list:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, cast, List, NamedTuple, Optional
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
|
@ -133,7 +133,7 @@ class DTensorSpec:
|
|||
return self.mesh
|
||||
|
||||
@property
|
||||
def dim_map(self) -> List[int]:
|
||||
def dim_map(self) -> list[int]:
|
||||
"""
|
||||
dim_map is a property we derive from `placements` of
|
||||
the distributed tensor. It simply return a list of ints
|
||||
|
|
@ -170,7 +170,7 @@ class DTensorSpec:
|
|||
return r
|
||||
|
||||
@property
|
||||
def num_shards_map(self) -> List[int]:
|
||||
def num_shards_map(self) -> list[int]:
|
||||
"""
|
||||
dim_map is a property we derive from `placements` of
|
||||
the distributed tensor. Unlike `dim_map`, `num_shards_map`
|
||||
|
|
@ -193,7 +193,7 @@ class DTensorSpec:
|
|||
return r
|
||||
|
||||
@property
|
||||
def sums(self) -> List[int]:
|
||||
def sums(self) -> list[int]:
|
||||
"""
|
||||
sums is a property we derive from `placements` of the
|
||||
distributed tensor. It simply return a list of ints where
|
||||
|
|
@ -209,8 +209,8 @@ class DTensorSpec:
|
|||
def from_dim_map(
|
||||
cls,
|
||||
mesh: DeviceMesh,
|
||||
dim_map: List[int],
|
||||
sums: List[int],
|
||||
dim_map: list[int],
|
||||
sums: list[int],
|
||||
tensor_meta: Optional[TensorMeta] = None,
|
||||
) -> "DTensorSpec":
|
||||
"""
|
||||
|
|
@ -228,7 +228,7 @@ class DTensorSpec:
|
|||
a class:`DTensorSpec` object
|
||||
"""
|
||||
# by default replicate on device mesh dims
|
||||
placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)]
|
||||
placements: list[Placement] = [Replicate() for _ in range(mesh.ndim)]
|
||||
|
||||
# find all mesh dims that need pending reductions
|
||||
for s in sums:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverload
|
||||
|
|
@ -22,9 +23,9 @@ except ImportError:
|
|||
|
||||
# Common type aliases
|
||||
ArgsType = tuple[object, ...]
|
||||
KwargsType = Dict[str, object]
|
||||
KwargsType = dict[str, object]
|
||||
|
||||
PlacementList = List[Optional[Placement]]
|
||||
PlacementList = list[Optional[Placement]]
|
||||
|
||||
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould
|
||||
# be the same set of possibilities.
|
||||
|
|
@ -86,7 +87,7 @@ class PlacementStrategy:
|
|||
# we need a nested list to record the cost for each
|
||||
# operand of this operator, and for each operand of
|
||||
# this operator it might have multiple placement strategies
|
||||
redistribute_cost: Optional[List[List[float]]] = None
|
||||
redistribute_cost: Optional[list[list[float]]] = None
|
||||
|
||||
@cached_property
|
||||
def output_spec(self) -> DTensorSpec:
|
||||
|
|
@ -130,9 +131,9 @@ class OpStrategy(StrategyType):
|
|||
OpStrategy that consists of a list of placement strategies associated with the op
|
||||
"""
|
||||
|
||||
def __init__(self, strategies: List[PlacementStrategy]) -> None:
|
||||
def __init__(self, strategies: list[PlacementStrategy]) -> None:
|
||||
super().__init__()
|
||||
self.strategies: List[PlacementStrategy] = strategies
|
||||
self.strategies: list[PlacementStrategy] = strategies
|
||||
|
||||
def __str__(self) -> str:
|
||||
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
|
||||
|
|
@ -203,7 +204,7 @@ class RuntimeSchemaInfo:
|
|||
# Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
|
||||
static_argnum: int = 100
|
||||
# This static_kwargkey records static kwarg names which would affect sharding prop
|
||||
static_kwargkey: Optional[List[str]] = None
|
||||
static_kwargkey: Optional[list[str]] = None
|
||||
# each op can decide if it wants to use pytree flatten/unflatten during operator
|
||||
# eager execution, by default we don't need to do flatten/unflatten, only if the
|
||||
# op indicate it needs to, this is to accelerate eager performance.
|
||||
|
|
@ -270,7 +271,7 @@ class OpSchema:
|
|||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
args_schema: List[str] = []
|
||||
args_schema: list[str] = []
|
||||
mesh_shape = None
|
||||
for arg in self.args_schema:
|
||||
if isinstance(arg, DTensorSpec):
|
||||
|
|
@ -404,7 +405,7 @@ class OpSchema:
|
|||
|
||||
def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
|
||||
suggestion_args_spec = self.args_spec
|
||||
new_arg_schema: List[object] = []
|
||||
new_arg_schema: list[object] = []
|
||||
idx_of_args_spec = 0
|
||||
if (
|
||||
origin_schema.schema_info is not None
|
||||
|
|
@ -448,9 +449,9 @@ class OpInfo:
|
|||
|
||||
mesh: DeviceMesh
|
||||
schema: OpSchema
|
||||
flat_args_schema: List[object]
|
||||
flat_args_schema: list[object]
|
||||
local_args: Sequence[object]
|
||||
local_kwargs: Dict[str, object]
|
||||
local_kwargs: dict[str, object]
|
||||
args_tree_spec: Optional[TreeSpec] = None
|
||||
|
||||
# the output sharding info
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import string
|
||||
from typing import cast, Dict, List, Optional
|
||||
from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
||||
|
|
@ -20,12 +20,12 @@ def _replace_char_in_str(string: str, new_char: str, idx: int) -> str:
|
|||
|
||||
def _gen_reshard_suggestions(
|
||||
op_schema: OpSchema,
|
||||
input_dims: List[str],
|
||||
input_dims: list[str],
|
||||
input_specs: tuple[DTensorSpec, ...],
|
||||
dim_to_sharding: Dict[str, int],
|
||||
pending_sum: List[int],
|
||||
dim_to_sharding: dict[str, int],
|
||||
pending_sum: list[int],
|
||||
) -> OutputSharding:
|
||||
suggested_arg_specs: List[DTensorSpec] = []
|
||||
suggested_arg_specs: list[DTensorSpec] = []
|
||||
for input_dim, input_spec in zip(input_dims, input_specs):
|
||||
dim_map = [dim_to_sharding[dim] for dim in input_dim]
|
||||
suggested_arg_specs.append(
|
||||
|
|
@ -49,7 +49,7 @@ def einop_rule(
|
|||
op_schema: OpSchema,
|
||||
*,
|
||||
linearity: bool = False,
|
||||
enforce_sharding: Optional[Dict[str, int]] = None,
|
||||
enforce_sharding: Optional[dict[str, int]] = None,
|
||||
) -> OutputSharding:
|
||||
"""
|
||||
Propagate the sharding of inputs to output for ops whose data moves according to einsum notation.
|
||||
|
|
@ -77,12 +77,12 @@ def einop_rule(
|
|||
# NOTE: only support single output unless needed in future
|
||||
output_dim = output_dims[0]
|
||||
|
||||
dim_to_sharding: Dict[str, int] = {}
|
||||
dim_to_size: Dict[str, int] = {}
|
||||
dim_to_sharding: dict[str, int] = {}
|
||||
dim_to_size: dict[str, int] = {}
|
||||
# record pending sum, key is mesh dimension, value is pending sum
|
||||
# counter across input specs
|
||||
pending_sums_counter: Dict[int, int] = {}
|
||||
seen_shardings: Dict[int, str] = {}
|
||||
pending_sums_counter: dict[int, int] = {}
|
||||
seen_shardings: dict[int, str] = {}
|
||||
needs_reshard = False
|
||||
|
||||
def merge_sharding(dim: str, a: int, b: int) -> int:
|
||||
|
|
@ -240,7 +240,7 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
|
|||
input_specs = op_schema.args_spec
|
||||
max_dim = max(input.ndim for input in input_specs)
|
||||
dimchars = []
|
||||
singleton_counter: List[int] = [0] * max_dim
|
||||
singleton_counter: list[int] = [0] * max_dim
|
||||
for input in input_specs:
|
||||
start_dim = max_dim - input.ndim
|
||||
p = alphabet[start_dim:max_dim]
|
||||
|
|
@ -271,7 +271,7 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi
|
|||
|
||||
fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}"
|
||||
|
||||
enforce_sharding: Dict[str, int] = {}
|
||||
enforce_sharding: dict[str, int] = {}
|
||||
if _is_inplace_op(op_schema.op):
|
||||
# inplace op should keep the input sharding it writes to
|
||||
for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# implement matrix related ops for distributed tensor
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
||||
|
|
@ -32,9 +31,9 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding:
|
|||
assert weight_spec.tensor_meta is not None
|
||||
in_shape = input_spec.tensor_meta.shape
|
||||
weight_shape = weight_spec.tensor_meta.shape
|
||||
assert isinstance(stride, List)
|
||||
assert isinstance(padding, List)
|
||||
assert isinstance(dilation, List)
|
||||
assert isinstance(stride, list)
|
||||
assert isinstance(padding, list)
|
||||
assert isinstance(dilation, list)
|
||||
assert isinstance(weight_shape, torch.Size)
|
||||
N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3]
|
||||
C_out = weight_shape[0]
|
||||
|
|
@ -84,7 +83,7 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
|
|||
assert isinstance(grad_output_spec, DTensorSpec)
|
||||
assert isinstance(input_spec, DTensorSpec)
|
||||
assert isinstance(weight_spec, DTensorSpec)
|
||||
assert isinstance(bias_shape_opt, List)
|
||||
assert isinstance(bias_shape_opt, list)
|
||||
assert input_spec.tensor_meta is not None
|
||||
weight_tensor_meta = weight_spec.tensor_meta
|
||||
bias_tensor_meta = TensorMeta(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Set
|
||||
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec
|
||||
|
|
@ -15,13 +14,13 @@ from torch.distributed.tensor.placement_types import (
|
|||
|
||||
@dataclass
|
||||
class EinsumDims:
|
||||
contracting_dims: List[str]
|
||||
batch_dims: List[str]
|
||||
lhs_out_only_dims: List[str]
|
||||
rhs_out_only_dims: List[str]
|
||||
contracting_dims: list[str]
|
||||
batch_dims: list[str]
|
||||
lhs_out_only_dims: list[str]
|
||||
rhs_out_only_dims: list[str]
|
||||
|
||||
@classmethod
|
||||
def parse_equation(cls, equation: str) -> tuple[List[str], str]:
|
||||
def parse_equation(cls, equation: str) -> tuple[list[str], str]:
|
||||
# parse einop equation and extract arg specs
|
||||
"""
|
||||
Parse the einsum equation str to input dim chars and output dim char
|
||||
|
|
@ -37,12 +36,12 @@ class EinsumDims:
|
|||
return input_dims, output_dim
|
||||
|
||||
@classmethod
|
||||
def parse_dims(cls, input_dims: List[str], output_dim: str) -> "EinsumDims":
|
||||
def parse_dims(cls, input_dims: list[str], output_dim: str) -> "EinsumDims":
|
||||
"""
|
||||
Parse the dims and extract the contracting, batch, and free dimensions
|
||||
for the left and right hand sides.
|
||||
"""
|
||||
dim_char_set: Set[str] = set()
|
||||
dim_char_set: set[str] = set()
|
||||
for input_dim in input_dims:
|
||||
dim_char_set.update(input_dim)
|
||||
|
||||
|
|
@ -104,7 +103,7 @@ def gen_einsum_strategies(
|
|||
|
||||
# placement list stores placements of [output, input1, input2, ...]
|
||||
# first we always have replicate all for inputs and output
|
||||
placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
|
||||
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
|
||||
mesh_dim_strategies.append(placement_list)
|
||||
|
||||
# split batch dim
|
||||
|
|
@ -131,7 +130,7 @@ def gen_einsum_strategies(
|
|||
lhs_free_dim = output_dim.index(lhs_dim)
|
||||
# this means split the lhs input and output
|
||||
# i.e. S(0), R -> S(0)
|
||||
lhs_placement_list: List[Placement] = [
|
||||
lhs_placement_list: list[Placement] = [
|
||||
Shard(lhs_free_dim),
|
||||
Shard(lhs_free_dim),
|
||||
Replicate(),
|
||||
|
|
@ -141,7 +140,7 @@ def gen_einsum_strategies(
|
|||
# split rhs free dim
|
||||
for rhs_dim in edims.rhs_out_only_dims:
|
||||
rhs_free_dim = output_dim.index(rhs_dim)
|
||||
rhs_placement_list: List[Placement] = [
|
||||
rhs_placement_list: list[Placement] = [
|
||||
Shard(rhs_free_dim),
|
||||
Replicate(),
|
||||
Shard(rhs_free_dim),
|
||||
|
|
@ -150,7 +149,7 @@ def gen_einsum_strategies(
|
|||
|
||||
# linearity strategy
|
||||
if linearity:
|
||||
linearity_placement_list: List[Placement] = [Partial()]
|
||||
linearity_placement_list: list[Placement] = [Partial()]
|
||||
for input_dim in input_dims:
|
||||
linearity_placement_list.append(Partial())
|
||||
mesh_dim_strategies.append(linearity_placement_list)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import cast, List, Optional, Sequence, Union
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
|
@ -148,11 +149,11 @@ class _NormPartial(Partial):
|
|||
return 1 + hash(self.norm_type)
|
||||
|
||||
|
||||
def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]:
|
||||
def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]:
|
||||
if dims_arg is None:
|
||||
return None
|
||||
dims = cast(List[int], as_list(dims_arg))
|
||||
dims = cast(List[int], normalize_dims(dims, ndim))
|
||||
dims = cast(list[int], as_list(dims_arg))
|
||||
dims = cast(list[int], normalize_dims(dims, ndim))
|
||||
empty_dims = [[0], [-1], []]
|
||||
if ndim == 0 and dims_arg in empty_dims:
|
||||
return None
|
||||
|
|
@ -160,8 +161,8 @@ def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]:
|
|||
|
||||
|
||||
def _infer_reduce_dims_map(
|
||||
reduction_dims: List[int], input_ndim: int, keep_dim=False
|
||||
) -> List[int]:
|
||||
reduction_dims: list[int], input_ndim: int, keep_dim=False
|
||||
) -> list[int]:
|
||||
reduction_dims_map = []
|
||||
new_dim_count = 0
|
||||
for input_dim in range(input_ndim):
|
||||
|
|
@ -179,7 +180,7 @@ def _infer_reduce_dims_map(
|
|||
def _replicate_dims_start_at(
|
||||
placements: Sequence[Placement], start_dim: int = 0
|
||||
) -> tuple[Placement, ...]:
|
||||
new_placements: List[Placement] = []
|
||||
new_placements: list[Placement] = []
|
||||
for p in placements:
|
||||
if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
|
||||
new_placements.append(Replicate()) # make it replicate
|
||||
|
|
@ -192,7 +193,7 @@ def _replicate_dims_start_at(
|
|||
def _skip_dim(
|
||||
placements: tuple[Placement, ...], skipped_dim: int
|
||||
) -> tuple[Placement, ...]:
|
||||
new_placements: List[Placement] = []
|
||||
new_placements: list[Placement] = []
|
||||
for p in placements:
|
||||
if isinstance(p, Shard) and p.dim >= skipped_dim:
|
||||
new_placements.append(Shard(p.dim - 1))
|
||||
|
|
@ -202,10 +203,10 @@ def _skip_dim(
|
|||
|
||||
|
||||
def replicate_reduction_dims(
|
||||
placements: tuple[Placement, ...], reduction_dims: List[int]
|
||||
placements: tuple[Placement, ...], reduction_dims: list[int]
|
||||
) -> tuple[Placement, ...]:
|
||||
# replicate the reduction dims if not reduction_linear
|
||||
new_placements: List[Placement] = []
|
||||
new_placements: list[Placement] = []
|
||||
|
||||
for p in placements:
|
||||
if p.is_partial():
|
||||
|
|
@ -220,14 +221,14 @@ def replicate_reduction_dims(
|
|||
|
||||
def map_placements_after_reduction(
|
||||
placements: tuple[Placement, ...],
|
||||
reduction_dims: List[int],
|
||||
reduction_dims_map: List[int],
|
||||
reduction_dims: list[int],
|
||||
reduction_dims_map: list[int],
|
||||
reduction_op: ReductionOpType,
|
||||
) -> tuple[Placement, ...]:
|
||||
"""
|
||||
Map each placement based on the output shape after reduction.
|
||||
"""
|
||||
new_placements: List[Placement] = []
|
||||
new_placements: list[Placement] = []
|
||||
for placement in placements:
|
||||
if isinstance(placement, (Replicate, Partial)):
|
||||
new_placements.append(placement)
|
||||
|
|
@ -253,7 +254,7 @@ def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement:
|
|||
def common_reduction_strategy(
|
||||
mesh: DeviceMesh,
|
||||
input_strategy: OpStrategy,
|
||||
reduce_dims: List[int],
|
||||
reduce_dims: list[int],
|
||||
keep_dim: bool = False,
|
||||
reduction_linear: bool = True,
|
||||
reduction_op: ReductionOpType = "sum",
|
||||
|
|
@ -406,7 +407,7 @@ def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrateg
|
|||
assert isinstance(input_tuple_strategy, TupleStrategy)
|
||||
norm_type = args_schema[1] if len(args_schema) > 1 else 2
|
||||
assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
|
||||
output_tuple_strategy_childs: List[OpStrategy] = []
|
||||
output_tuple_strategy_childs: list[OpStrategy] = []
|
||||
for op_strategy in input_tuple_strategy.childs:
|
||||
assert isinstance(op_strategy, OpStrategy), f"{op_strategy}"
|
||||
reduce_dims = list(range(op_strategy.ndim))
|
||||
|
|
@ -451,7 +452,7 @@ def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrate
|
|||
args_schema = op_schema.args_schema
|
||||
input_strategy = args_schema[0]
|
||||
assert isinstance(input_strategy, OpStrategy), f"{input_strategy}"
|
||||
output_strategies: List[PlacementStrategy] = []
|
||||
output_strategies: list[PlacementStrategy] = []
|
||||
for placement_strategy in input_strategy.strategies:
|
||||
replicate_placements = tuple(Replicate() for _ in range(mesh.ndim))
|
||||
replicate_spec = DTensorSpec(
|
||||
|
|
@ -912,14 +913,14 @@ def layer_norm_bwd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
|
|||
axis = input_ndim - len(normalized_size)
|
||||
outer_dims = list(range(axis))
|
||||
|
||||
assert isinstance(output_mask, List) and len(output_mask) == 3
|
||||
assert isinstance(output_mask, list) and len(output_mask) == 3
|
||||
|
||||
# output triple: (d_input, d_weight, d_bias)
|
||||
out_tuple_strategy = OpStrategy([])
|
||||
for idx, input_placement_strategy in enumerate(input_strategy.strategies):
|
||||
# args for PlacementStrategy
|
||||
output_specs_list: List[Optional[DTensorSpec]] = []
|
||||
input_specs_list: List[DTensorSpec] = []
|
||||
output_specs_list: list[Optional[DTensorSpec]] = []
|
||||
input_specs_list: list[DTensorSpec] = []
|
||||
redistribute_costs = []
|
||||
|
||||
input_src_spec = input_placement_strategy.output_spec
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# implement matrix related ops for distributed tensor
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
|
@ -415,7 +414,7 @@ def scaled_dot_product_efficient_attention_strategy(
|
|||
has_attn_bias = op_schema.args_schema[3] is not None
|
||||
compute_log_sumexp = op_schema.args_schema[4]
|
||||
|
||||
single_mesh_dim_strategies: List[PlacementList] = []
|
||||
single_mesh_dim_strategies: list[PlacementList] = []
|
||||
|
||||
# placement list stores placements of [outputs, inputs]
|
||||
# in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import List, Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
|
@ -464,7 +464,7 @@ def common_pointwise_strategy(
|
|||
|
||||
for placement_strategy in followed_strategy.strategies:
|
||||
spec_to_follow = placement_strategy.output_spec
|
||||
out_placements: List[Placement] = []
|
||||
out_placements: list[Placement] = []
|
||||
for placement in spec_to_follow.placements:
|
||||
if isinstance(placement, Shard):
|
||||
shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape))
|
||||
|
|
@ -479,8 +479,8 @@ def common_pointwise_strategy(
|
|||
else:
|
||||
out_placements.append(placement)
|
||||
|
||||
input_specs: List[DTensorSpec] = []
|
||||
redistribute_costs: List[List[float]] = []
|
||||
input_specs: list[DTensorSpec] = []
|
||||
redistribute_costs: list[list[float]] = []
|
||||
for input_arg in args_schema:
|
||||
if isinstance(input_arg, OpStrategy):
|
||||
# every arg follow the out_placements, but need to handle broadcasting
|
||||
|
|
@ -616,11 +616,11 @@ def list_pointwise_strategy(
|
|||
OpStrategy: generated strategy
|
||||
"""
|
||||
|
||||
def args_tuple_strategies(args_schema: tuple[object, ...]) -> List[TupleStrategy]:
|
||||
def args_tuple_strategies(args_schema: tuple[object, ...]) -> list[TupleStrategy]:
|
||||
first_arg = args_schema[0]
|
||||
assert isinstance(first_arg, TupleStrategy)
|
||||
strategy_len = len(first_arg.childs)
|
||||
tuple_strategies: List[TupleStrategy] = []
|
||||
tuple_strategies: list[TupleStrategy] = []
|
||||
for arg_idx, arg in enumerate(args_schema):
|
||||
if isinstance(arg, TupleStrategy):
|
||||
# every tuple strategy should have the same length
|
||||
|
|
@ -639,10 +639,10 @@ def list_pointwise_strategy(
|
|||
|
||||
args_strategies = args_tuple_strategies(op_schema.args_schema)
|
||||
follow_strategy: TupleStrategy = args_strategies[0]
|
||||
list_strategy: List[OpStrategy] = []
|
||||
list_strategy: list[OpStrategy] = []
|
||||
for child_idx, child_strtgy in enumerate(follow_strategy.childs):
|
||||
assert isinstance(child_strtgy, OpStrategy)
|
||||
args_schema: List[StrategyType] = [
|
||||
args_schema: list[StrategyType] = [
|
||||
arg_strategy.childs[child_idx] for arg_strategy in args_strategies
|
||||
]
|
||||
pointwise_strategy: OpStrategy = common_pointwise_strategy(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from typing import cast, List, Optional, Sequence, Sized
|
||||
from collections.abc import Sequence, Sized
|
||||
from typing import cast, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
|
|
@ -465,7 +466,7 @@ def _derive_follow_placements_from_tuple_strategy(
|
|||
# current replicate, just follow new placement
|
||||
return new_placement
|
||||
|
||||
follow_placements: Optional[List[Placement]] = None
|
||||
follow_placements: Optional[list[Placement]] = None
|
||||
for arg_strategy in tuple_strategy.childs:
|
||||
assert isinstance(arg_strategy, OpStrategy)
|
||||
for placement_strategy in arg_strategy.strategies:
|
||||
|
|
@ -489,7 +490,7 @@ def normalize_shard_for_stack(
|
|||
) -> Sequence[Placement]:
|
||||
# stack op would "insert" new dim, so all sharded dim >= the inserted dim need to
|
||||
# be normalized with the new Shard placement
|
||||
normalized_placements: List[Placement] = []
|
||||
normalized_placements: list[Placement] = []
|
||||
for placement in placements:
|
||||
if isinstance(placement, Shard) and placement.dim >= insert_dim:
|
||||
normalized_placements.append(Shard(placement.dim + 1))
|
||||
|
|
@ -575,7 +576,7 @@ def prop_index_select(op_schema: OpSchema) -> OutputSharding:
|
|||
assert isinstance(dim, int)
|
||||
assert isinstance(indices_spec, DTensorSpec)
|
||||
|
||||
all_indices_spec: List[Optional[DTensorSpec]] = [
|
||||
all_indices_spec: list[Optional[DTensorSpec]] = [
|
||||
indices_spec if dim == i else None for i in range(values_spec.ndim)
|
||||
]
|
||||
|
||||
|
|
@ -620,8 +621,8 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
|
|||
values_spec, multi_indices_spec = op_schema.args_schema
|
||||
assert isinstance(values_spec, DTensorSpec)
|
||||
assert isinstance(multi_indices_spec, list)
|
||||
multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec)
|
||||
valid_indices_spec: List[tuple[int, DTensorSpec]] = [
|
||||
multi_indices_spec = cast(list[Optional[DTensorSpec]], multi_indices_spec)
|
||||
valid_indices_spec: list[tuple[int, DTensorSpec]] = [
|
||||
(i, a) for i, a in enumerate(multi_indices_spec) if a is not None
|
||||
]
|
||||
|
||||
|
|
@ -731,7 +732,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
|
|||
schema_info=RuntimeSchemaInfo(1),
|
||||
)
|
||||
def split_rule(op_schema: OpSchema) -> OutputSharding:
|
||||
output_spec_list: List[DTensorSpec] = []
|
||||
output_spec_list: list[DTensorSpec] = []
|
||||
input_spec = cast(DTensorSpec, op_schema.args_schema[0])
|
||||
ndim = input_spec.ndim
|
||||
split_size_or_sections = op_schema.args_schema[1]
|
||||
|
|
@ -769,7 +770,7 @@ def split_rule(op_schema: OpSchema) -> OutputSharding:
|
|||
),
|
||||
)
|
||||
|
||||
def size_split(N, i) -> List:
|
||||
def size_split(N, i) -> list:
|
||||
# Last chunk will be smaller if the tensor size N
|
||||
# along the given dimension dim is not divisible by i.
|
||||
assert i > 0
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Set, Union
|
||||
from typing import Callable, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -216,7 +217,7 @@ def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap:
|
|||
# other dims are passed through
|
||||
if end_dim < 0:
|
||||
end_dim += ndim
|
||||
results: List[DimSpec] = [InputDim(i) for i in range(start_dim)]
|
||||
results: list[DimSpec] = [InputDim(i) for i in range(start_dim)]
|
||||
results.append(
|
||||
Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1)))
|
||||
)
|
||||
|
|
@ -414,7 +415,7 @@ def dim_unsqueeze(ndim: int, dim: int) -> DimMap:
|
|||
|
||||
def dim_view_as_real(shape: Shape) -> DimMap:
|
||||
ndim = len(shape)
|
||||
results: List[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
|
||||
results: list[DimSpec] = [InputDim(i) for i in range(ndim - 1)]
|
||||
# each complex number is split into two real numbers,
|
||||
# resulting in one more dimension of size 2
|
||||
results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0))
|
||||
|
|
@ -442,7 +443,7 @@ def dim_reduction(
|
|||
)
|
||||
|
||||
|
||||
dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
|
||||
dim_maps: dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
|
||||
torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
|
||||
torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
|
||||
torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
|
||||
|
|
@ -488,11 +489,11 @@ def propagate_shape_and_sharding(
|
|||
assert len(input_src_placements) == len(mesh_sizes)
|
||||
# for each input dim, for each mesh dim, provides a list of possible shardable dimensions
|
||||
mesh_ndim = len(mesh_sizes)
|
||||
shardable_dims: Dict[int, List[bool]] = {}
|
||||
shardable_dims: dict[int, list[bool]] = {}
|
||||
|
||||
# in case an input dimension disappears (e.g. collapsing, reduction)
|
||||
# we cannot shard in that dimension (we need a replication fall-back rule)
|
||||
seen_input_dims: Set[int] = set()
|
||||
seen_input_dims: set[int] = set()
|
||||
|
||||
def collect_used_inputs(cmd: DimSpec) -> None:
|
||||
if isinstance(cmd, InputDim):
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@
|
|||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Callable, cast, Iterable, List, Optional, Sequence, TypeVar, Union
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Callable, cast, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
|
|
@ -35,7 +36,7 @@ _P = ParamSpec("_P")
|
|||
# pyre-fixme[3]: Return type must be annotated.
|
||||
# pyre-fixme[2]: Parameter must be annotated.
|
||||
def register_prop_rule(
|
||||
op: Union[torch._ops.OpOverload, List[torch._ops.OpOverload]],
|
||||
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
|
||||
schema_info: Optional[RuntimeSchemaInfo] = None,
|
||||
) -> Callable[
|
||||
[Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding]
|
||||
|
|
@ -101,9 +102,9 @@ def register_op_strategy(
|
|||
|
||||
|
||||
def as_list(
|
||||
x: Union[List[object], object]
|
||||
x: Union[list[object], object]
|
||||
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
|
||||
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
|
||||
) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type]
|
||||
# During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
|
||||
# which is an object but treated as a list by the tracer. Therefore, keep
|
||||
# `immutable_list` intact here as well.
|
||||
|
|
@ -178,7 +179,7 @@ def is_tensor_partial(spec: DTensorSpec) -> bool:
|
|||
|
||||
def infer_broadcast_dims_map(
|
||||
common_shape: torch.Size, input_shape: torch.Size
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
# infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
|
||||
# this is aligned with the broadcast semantics
|
||||
common_ndim = len(common_shape)
|
||||
|
|
@ -193,10 +194,10 @@ def infer_broadcast_dims_map(
|
|||
def map_placements_after_broadcast(
|
||||
placements: tuple[Placement, ...],
|
||||
shape: torch.Size,
|
||||
broadcast_dims_map: List[int],
|
||||
broadcast_dims_map: list[int],
|
||||
) -> tuple[Placement, ...]:
|
||||
"""Map each placement based on the output shape after broadcast."""
|
||||
new_placements: List[Placement] = []
|
||||
new_placements: list[Placement] = []
|
||||
for placement in placements:
|
||||
if isinstance(placement, (Replicate, Partial)):
|
||||
new_placements.append(placement)
|
||||
|
|
@ -223,8 +224,8 @@ def map_placements_after_broadcast(
|
|||
|
||||
def generate_redistribute_costs(
|
||||
src_strategy: OpStrategy, dst_spec: DTensorSpec
|
||||
) -> List[float]:
|
||||
redistribute_costs: List[float] = [
|
||||
) -> list[float]:
|
||||
redistribute_costs: list[float] = [
|
||||
redistribute_cost(strat.output_spec, dst_spec)
|
||||
for strat in src_strategy.strategies
|
||||
]
|
||||
|
|
@ -235,7 +236,7 @@ def generate_redistribute_costs(
|
|||
def expand_to_full_mesh_op_strategy(
|
||||
mesh: DeviceMesh,
|
||||
op_schema: OpSchema,
|
||||
single_mesh_dim_strategies: List[PlacementList],
|
||||
single_mesh_dim_strategies: list[PlacementList],
|
||||
*,
|
||||
input_index: int = 1,
|
||||
inplace_op: bool = False,
|
||||
|
|
@ -247,14 +248,14 @@ def expand_to_full_mesh_op_strategy(
|
|||
|
||||
all_strategies = []
|
||||
for strategy_comb in strategy_combs:
|
||||
spec_list: List[Optional[DTensorSpec]] = []
|
||||
spec_list: list[Optional[DTensorSpec]] = []
|
||||
for specs in zip(*strategy_comb):
|
||||
if specs[0] is not None:
|
||||
spec_list.append(DTensorSpec(mesh, specs))
|
||||
else:
|
||||
spec_list.append(None)
|
||||
|
||||
input_specs: List[DTensorSpec] = [
|
||||
input_specs: list[DTensorSpec] = [
|
||||
s for s in spec_list[input_index:] if isinstance(s, DTensorSpec)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import contextlib
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -110,12 +110,12 @@ class _RNGStateTracker:
|
|||
f"{self.__class__.__name__} instantiation requires the presence of CUDA/CUDA-like device"
|
||||
)
|
||||
|
||||
self._states: Dict[str, Tensor] = {}
|
||||
self._states: dict[str, Tensor] = {}
|
||||
self._devices = [self._device_handle.current_device()]
|
||||
self._use_distribute_region = True
|
||||
|
||||
@property
|
||||
def rng_states(self) -> Dict[str, Tensor]:
|
||||
def rng_states(self) -> dict[str, Tensor]:
|
||||
return self._states
|
||||
|
||||
@property
|
||||
|
|
@ -267,7 +267,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
mesh = spec.mesh
|
||||
# note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP
|
||||
# case. Replace the custom logic with dim_map once we support it.
|
||||
dim_map: List[Union[int, List[int]]] = [-1] * spec.ndim
|
||||
dim_map: list[Union[int, list[int]]] = [-1] * spec.ndim
|
||||
for i, placement in enumerate(spec.placements):
|
||||
if isinstance(placement, Shard):
|
||||
shard_dim = placement.dim
|
||||
|
|
@ -275,7 +275,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
dim_map[shard_dim] = [i]
|
||||
else:
|
||||
mesh_dim_list = dim_map[shard_dim]
|
||||
assert isinstance(mesh_dim_list, List)
|
||||
assert isinstance(mesh_dim_list, list)
|
||||
mesh_dim_list.append(i)
|
||||
|
||||
# Compute shard coordinate:
|
||||
|
|
@ -291,7 +291,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
shard_idx = 0
|
||||
total_num_shards = 1
|
||||
# the tensor dim is sharded on more than 1 mesh dim
|
||||
if isinstance(mesh_dim, List):
|
||||
if isinstance(mesh_dim, list):
|
||||
rank_coord = [mesh_coordinate[d] for d in mesh_dim]
|
||||
num_shards = [mesh_size[d] for d in mesh_dim]
|
||||
# compute the shard idx and total number of shards
|
||||
|
|
@ -356,7 +356,7 @@ class OffsetBasedRNGTracker(_RNGStateTracker):
|
|||
self.set_offset("parallel-rng", old_offset + numel)
|
||||
|
||||
def _calc_shard_linear_idx(
|
||||
self, shard_coord: List[int], shard_size: List[int]
|
||||
self, shard_coord: list[int], shard_size: list[int]
|
||||
) -> int:
|
||||
# compute shard linear index
|
||||
shard_linear_idx = 0
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import cast, List, NamedTuple
|
||||
from functools import cache
|
||||
from typing import cast, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
|
@ -24,13 +24,13 @@ class _TransformInfo(NamedTuple):
|
|||
mesh_dim: int
|
||||
src_dst_placements: tuple[Placement, Placement]
|
||||
# logical_shape on this mesh dimension
|
||||
logical_shape: List[int]
|
||||
logical_shape: list[int]
|
||||
|
||||
|
||||
def _gen_transform_infos_non_cached(
|
||||
src_spec: DTensorSpec,
|
||||
dst_spec: DTensorSpec,
|
||||
) -> List[_TransformInfo]:
|
||||
) -> list[_TransformInfo]:
|
||||
"""
|
||||
Generate the transform infos from the source placements to the target placements.
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ def _gen_transform_infos_non_cached(
|
|||
the former is a nested-sharding of a tensor already already sharded dimension 0, whereras
|
||||
the latter is the first sharding on tensor dimension 0.
|
||||
"""
|
||||
transform_infos: List[_TransformInfo] = []
|
||||
transform_infos: list[_TransformInfo] = []
|
||||
|
||||
device_mesh = src_spec.device_mesh
|
||||
my_coordinate = device_mesh.get_coordinate()
|
||||
|
|
@ -145,11 +145,11 @@ def _gen_transform_infos_non_cached(
|
|||
return transform_infos
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@cache
|
||||
def _gen_transform_infos(
|
||||
src_spec: DTensorSpec,
|
||||
dst_spec: DTensorSpec,
|
||||
) -> List[_TransformInfo]:
|
||||
) -> list[_TransformInfo]:
|
||||
return _gen_transform_infos_non_cached(src_spec, dst_spec)
|
||||
|
||||
|
||||
|
|
@ -332,7 +332,7 @@ class Redistribute(torch.autograd.Function):
|
|||
is_backward=True,
|
||||
)
|
||||
# normalize the target placement to replicate if it is partial
|
||||
normalized_placements: List[Placement] = []
|
||||
normalized_placements: list[Placement] = []
|
||||
for previous_placement in previous_spec.placements:
|
||||
if previous_placement.is_partial():
|
||||
# keep target placement to replicate instead of partial in this case
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import threading
|
||||
from collections.abc import Sequence
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from typing import Callable, cast, Dict, List, Optional, Sequence, Union
|
||||
from typing import Callable, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._ops import OpOverload
|
||||
|
|
@ -51,18 +52,18 @@ class LocalLRUCache(threading.local):
|
|||
|
||||
class ShardingPropagator:
|
||||
def __init__(self) -> None:
|
||||
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
|
||||
self.op_strategy_funcs: Dict[
|
||||
self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
|
||||
self.op_strategy_funcs: dict[
|
||||
OpOverload,
|
||||
Callable[[DeviceMesh, OpSchema], StrategyType],
|
||||
] = {}
|
||||
# op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
|
||||
self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
|
||||
self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {}
|
||||
self.propagate_op_sharding = LocalLRUCache(
|
||||
self.propagate_op_sharding_non_cached
|
||||
)
|
||||
# op map to save indices of shape (and stride) args which may need to be modified in sharding prop
|
||||
self.op_to_shape_and_stride_idx: Dict[
|
||||
self.op_to_shape_and_stride_idx: dict[
|
||||
OpOverload, Union[int, tuple[int, int]]
|
||||
] = {
|
||||
# new factory ops
|
||||
|
|
@ -128,7 +129,7 @@ class ShardingPropagator:
|
|||
)
|
||||
|
||||
elif isinstance(fake_out, (tuple, list)):
|
||||
tensor_meta_list: List[Optional[TensorMeta]] = []
|
||||
tensor_meta_list: list[Optional[TensorMeta]] = []
|
||||
for fake_out_item in fake_out:
|
||||
if isinstance(fake_out_item, torch.Tensor):
|
||||
tensor_meta_list.append(
|
||||
|
|
@ -260,7 +261,7 @@ class ShardingPropagator:
|
|||
|
||||
# check if we need to redistribute the input
|
||||
needs_redistribute = False
|
||||
expected_input_specs: List[DTensorSpec] = []
|
||||
expected_input_specs: list[DTensorSpec] = []
|
||||
|
||||
# in case where the op does not specify input_specs and output_specs
|
||||
# is a DTensorSpec, we use output_specs as the spec for each DTensor
|
||||
|
|
@ -335,8 +336,8 @@ class ShardingPropagator:
|
|||
elif isinstance(op_strategy, TupleStrategy):
|
||||
# tuple strategy output sharding processing
|
||||
# runtime selected placement strategy for each TupleStrategy input arg
|
||||
selected_strategies: List[PlacementStrategy] = []
|
||||
out_spec_list: List[DTensorSpec] = []
|
||||
selected_strategies: list[PlacementStrategy] = []
|
||||
out_spec_list: list[DTensorSpec] = []
|
||||
for strategy in op_strategy.childs:
|
||||
assert isinstance(strategy, OpStrategy)
|
||||
selected_strategy = self._select_strategy(strategy)
|
||||
|
|
@ -344,7 +345,7 @@ class ShardingPropagator:
|
|||
out_spec_list.append(selected_strategy.output_spec)
|
||||
|
||||
needs_redistribute = False
|
||||
suggestion_args: List[object] = []
|
||||
suggestion_args: list[object] = []
|
||||
tensor_or_list_tensor_arg_idx = 0
|
||||
|
||||
for arg in op_schema.args_schema:
|
||||
|
|
@ -353,7 +354,7 @@ class ShardingPropagator:
|
|||
and isinstance(arg, (list, tuple))
|
||||
and isinstance(arg[0], DTensorSpec)
|
||||
):
|
||||
expected_input_spec_list: List[DTensorSpec] = []
|
||||
expected_input_spec_list: list[DTensorSpec] = []
|
||||
for idx, arg_spec in enumerate(arg):
|
||||
expected_input_spec = selected_strategies[idx].input_spec(
|
||||
tensor_or_list_tensor_arg_idx
|
||||
|
|
@ -461,7 +462,7 @@ class ShardingPropagator:
|
|||
# short cut with only one possible strategy
|
||||
return strategy.strategies[0]
|
||||
|
||||
strategy_costs: List[float] = []
|
||||
strategy_costs: list[float] = []
|
||||
for strtg in strategy.strategies:
|
||||
assert (
|
||||
strtg.redistribute_cost is not None
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
|
|
@ -34,12 +34,12 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
"""
|
||||
|
||||
__slots__ = ["_local_shards", "_storage_meta"]
|
||||
_local_shards: List[torch.Tensor]
|
||||
_local_shards: list[torch.Tensor]
|
||||
_storage_meta: TensorStorageMetadata
|
||||
|
||||
@staticmethod
|
||||
def __new__(
|
||||
cls, local_shards: List[torch.Tensor], local_offsets: List[tuple[int, ...]]
|
||||
cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]]
|
||||
) -> "LocalShardsWrapper":
|
||||
assert len(local_shards) > 0
|
||||
assert len(local_shards) == len(local_offsets)
|
||||
|
|
@ -206,7 +206,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
[shard.requires_grad_(requires_grad) for shard in self._local_shards]
|
||||
return self
|
||||
|
||||
def local_shards(self) -> List[torch.Tensor]:
|
||||
def local_shards(self) -> list[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of :class:`torch.Tensor' corresponding to the
|
||||
local shards for this rank. Returns an empty list if the current rank
|
||||
|
|
@ -214,7 +214,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
"""
|
||||
return self._local_shards
|
||||
|
||||
def local_sizes(self) -> List[torch.Size]:
|
||||
def local_sizes(self) -> list[torch.Size]:
|
||||
"""
|
||||
Returns a list of :class:`torch.Size' corresponding to the
|
||||
local sizes for the shards on this rank. Returns an empty list if the current rank
|
||||
|
|
@ -222,7 +222,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
"""
|
||||
return [chunk.sizes for chunk in self._storage_meta.chunks]
|
||||
|
||||
def local_offsets(self) -> List[torch.Size]:
|
||||
def local_offsets(self) -> list[torch.Size]:
|
||||
"""
|
||||
Returns a list of :class:`torch.Size' corresponding to the
|
||||
local offsets for the shards on this rank. Returns an empty list if the current rank
|
||||
|
|
@ -231,7 +231,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
return [chunk.offsets for chunk in self._storage_meta.chunks]
|
||||
|
||||
@property
|
||||
def local_chunks(self) -> List[ChunkStorageMetadata]:
|
||||
def local_chunks(self) -> list[ChunkStorageMetadata]:
|
||||
"""
|
||||
Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
|
||||
metadata for each tensor shard
|
||||
|
|
@ -247,7 +247,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
|
||||
def __create_write_items__(
|
||||
self, fqn: str, object: Any
|
||||
) -> List[WriteItem]: # pyre-ignore[2]
|
||||
) -> list[WriteItem]: # pyre-ignore[2]
|
||||
"""
|
||||
For compatibility with DCP, we support creation of WriteItems
|
||||
such that they can be saved properly.
|
||||
|
|
@ -268,7 +268,7 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new
|
|||
for tensor, chunks in zip(self.local_shards(), self.local_chunks)
|
||||
]
|
||||
|
||||
def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
|
||||
def __create_chunk_list__(self) -> list[ChunkStorageMetadata]:
|
||||
"""
|
||||
For compatibility with DCP, we support creation of chunk lists
|
||||
such that they can be saved properly.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
# implement matrix related ops for distributed tensor
|
||||
from typing import cast, Dict, List
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -106,7 +106,7 @@ def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size):
|
|||
def tp_convolution(
|
||||
op_call: torch._ops.OpOverload,
|
||||
local_tensor_args: tuple[object, ...],
|
||||
local_tensor_kwargs: Dict[str, object],
|
||||
local_tensor_kwargs: dict[str, object],
|
||||
) -> object:
|
||||
assert op_call == aten.convolution.default
|
||||
assert len(local_tensor_args) == 9
|
||||
|
|
@ -118,7 +118,7 @@ def tp_convolution(
|
|||
stride, padding, dilation = local_tensor_args[3:6]
|
||||
|
||||
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
|
||||
assert isinstance(padding, List)
|
||||
assert isinstance(padding, list)
|
||||
|
||||
if not _requires_data_exchange(padding):
|
||||
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
|
||||
|
|
@ -159,7 +159,7 @@ def tp_convolution(
|
|||
def tp_convolution_backward(
|
||||
op_call: torch._ops.OpOverload,
|
||||
local_tensor_args: tuple[object, ...],
|
||||
local_tensor_kwargs: Dict[str, object],
|
||||
local_tensor_kwargs: dict[str, object],
|
||||
) -> object:
|
||||
assert op_call == aten.convolution_backward.default
|
||||
assert len(local_tensor_args) == 11
|
||||
|
|
@ -172,7 +172,7 @@ def tp_convolution_backward(
|
|||
stride, padding, dilation = local_tensor_args[4:7]
|
||||
|
||||
assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation)
|
||||
assert isinstance(padding, List)
|
||||
assert isinstance(padding, list)
|
||||
|
||||
if not _requires_data_exchange(padding):
|
||||
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
|
||||
|
|
@ -229,7 +229,7 @@ def tp_convolution_backward(
|
|||
def convolution_handler(
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
# extract local tensor and sharding infos to a OpInfo
|
||||
op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
|
||||
|
|
@ -252,7 +252,7 @@ def convolution_handler(
|
|||
def convolution_backward_handler(
|
||||
op_call: torch._ops.OpOverload,
|
||||
args: tuple[object, ...],
|
||||
kwargs: Dict[str, object],
|
||||
kwargs: dict[str, object],
|
||||
) -> object:
|
||||
# Redistribute grad_output tensor to the same placement as input tensor
|
||||
args = list(args)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import cast, List, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
import torch.distributed.tensor._api as dtensor
|
||||
|
|
@ -165,7 +166,7 @@ def compute_local_shape_and_global_offset(
|
|||
|
||||
def compute_global_tensor_info(
|
||||
tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
|
||||
) -> tuple[List[int], List[int]]:
|
||||
) -> tuple[list[int], list[int]]:
|
||||
"""
|
||||
Compute the global size and stride of a DTensor from the given local tensor.
|
||||
The local size is multiplited by `world_size` per Sharding dim.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import re
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
|
|
@ -240,7 +240,7 @@ class CommDebugMode(TorchDispatchMode):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.comm_counts: Dict[Any, int] = defaultdict(int)
|
||||
self.comm_counts: dict[Any, int] = defaultdict(int)
|
||||
self.comm_module_counts = {}
|
||||
self.comm_module_operation_counts = {}
|
||||
self.comm_registry = set()
|
||||
|
|
@ -392,7 +392,7 @@ class CommDebugMode(TorchDispatchMode):
|
|||
|
||||
return json_dict
|
||||
|
||||
json_dict: Dict[str, Any] = {}
|
||||
json_dict: dict[str, Any] = {}
|
||||
add_json_information(json_dict, "Global")
|
||||
|
||||
# converts dictonary into json file
|
||||
|
|
@ -567,7 +567,7 @@ class CommDebugMode(TorchDispatchMode):
|
|||
def get_total_counts(self) -> int:
|
||||
return sum(self.comm_counts.values())
|
||||
|
||||
def get_comm_counts(self) -> Dict[Any, int]:
|
||||
def get_comm_counts(self) -> dict[Any, int]:
|
||||
"""Returns the communication counts as a dictionary.
|
||||
|
||||
Returns:
|
||||
|
|
@ -575,10 +575,10 @@ class CommDebugMode(TorchDispatchMode):
|
|||
"""
|
||||
return self.comm_counts
|
||||
|
||||
def get_parameter_info(self) -> Dict[str, Dict[str, Any]]:
|
||||
def get_parameter_info(self) -> dict[str, dict[str, Any]]:
|
||||
return self.advanced_module_tracker.module_parameters_dict
|
||||
|
||||
def get_sharding_info(self) -> Dict[str, Dict[str, Any]]:
|
||||
def get_sharding_info(self) -> dict[str, dict[str, Any]]:
|
||||
return self.advanced_module_tracker.sharding_dict
|
||||
|
||||
def __enter__(self):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from operator import itemgetter
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
|
@ -13,7 +12,7 @@ from torch.distributed.tensor import DTensor
|
|||
|
||||
inductor_decomps = select_decomp_table()
|
||||
|
||||
graphs: List[torch.fx.GraphModule] = []
|
||||
graphs: list[torch.fx.GraphModule] = []
|
||||
|
||||
|
||||
def fwd_bwd_compiler(fx_g, _):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from typing import List, Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ def _compute_local_shape_and_global_offset(
|
|||
global_shape: ShapeType,
|
||||
mesh: DeviceMesh,
|
||||
placements: Sequence[Placement],
|
||||
my_coordinate: List[int],
|
||||
my_coordinate: list[int],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
"""
|
||||
Same as torch.distributed._tensor._utils.compute_local_shape_and_global_offset but
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ torchrun --standalone --nnodes=1 --nproc-per-node=4 comm_mode_features_example.p
"""
import argparse
import os
from typing import Callable, Dict, Union
from typing import Callable, Union

import torch
import torch.nn as nn

@ -713,7 +713,7 @@ def run_example(world_size: int, rank: int, example_name: str) -> None:
# initializing class with all of the functions
instantiated_example = CommDebugModeExample(world_size, rank)
# dict that stores example code function names
name_to_example_code: Dict[str, Callable[[], None]] = {
name_to_example_code: dict[str, Callable[[], None]] = {
"MLP_distributed_sharding_display": instantiated_example.example_MLP_distributed_sharding_display,
"MLPStacked_distributed_sharding_display": instantiated_example.example_MLPStacked_distributed_sharding_display,
"MLP_module_tracing": instantiated_example.example_MLP_module_tracing,
@ -6,7 +6,7 @@ sharding with the DTensor API.
import argparse
import os
from functools import cached_property
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING

import torch
from torch.distributed.checkpoint.metadata import (

@ -43,12 +43,12 @@ supported_ops = [aten.view.default, aten._to_copy.default]
# this torch.Tensor subclass is a wrapper around all local shards associated
# with a single sharded embedding table.
class LocalShardsWrapper(torch.Tensor):
local_shards: List[torch.Tensor]
local_shards: list[torch.Tensor]
storage_meta: TensorStorageMetadata

@staticmethod
def __new__(
cls, local_shards: List[torch.Tensor], offsets: List[torch.Size]
cls, local_shards: list[torch.Tensor], offsets: list[torch.Size]
) -> "LocalShardsWrapper":
assert len(local_shards) > 0
assert len(local_shards) == len(offsets)

@ -98,19 +98,19 @@ class LocalShardsWrapper(torch.Tensor):
)

@property
def shards(self) -> List[torch.Tensor]:
def shards(self) -> list[torch.Tensor]:
return self.local_shards

@shards.setter
def shards(self, local_shards: List[torch.Tensor]):
def shards(self, local_shards: list[torch.Tensor]):
self.local_shards = local_shards

@cached_property
def shard_sizes(self) -> List[torch.Size]:
def shard_sizes(self) -> list[torch.Size]:
return [chunk.sizes for chunk in self.storage_meta.chunks]

@cached_property
def shard_offsets(self) -> List[torch.Size]:
def shard_offsets(self) -> list[torch.Size]:
return [chunk.offsets for chunk in self.storage_meta.chunks]


@ -158,7 +158,7 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
# this is the sharding placement we use in DTensor to represent row-wise sharding
# row_wise_sharding_placements means that the global tensor is sharded by first dim
# over the 1-d mesh.
row_wise_sharding_placements: List[Placement] = [Shard(0)]
row_wise_sharding_placements: list[Placement] = [Shard(0)]

# create a DTensor from the local shard
dtensor = DTensor.from_local(

@ -226,7 +226,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
###########################################################################
# example 1: transform local_shards into DTensor
# create the DTensorMetadata which torchrec should provide
row_wise_sharding_placements: List[Placement] = [Shard(0)]
row_wise_sharding_placements: list[Placement] = [Shard(0)]

# note: for uneven sharding, we need to specify the shape and stride because
# DTensor would assume even sharding and compute shape/stride based on the
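A condensed sketch (shapes and mesh are illustrative assumptions) of the row-wise pattern the example above relies on:

    # Illustrative only: wrap a local row shard into a DTensor that is Shard(0) over a 1-D mesh.
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import DTensor, Shard

    mesh = init_device_mesh("cpu", (4,))                  # assumes 4 ranks
    local_rows = torch.randn(2, 16)                       # each rank holds 2 of the 8 global rows
    placements: list[Shard] = [Shard(0)]                  # PEP 585 builtin generic
    dtensor = DTensor.from_local(local_rows, mesh, placements)
    assert dtensor.shape == (8, 16)                       # even sharding: local dim 0 * world_size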
@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Iterator

from torch.distributed.tensor._api import DTensor
from torch.distributed.tensor.experimental._attention import context_parallel
@ -6,9 +6,10 @@ import logging
import types
import weakref
from abc import ABC, abstractmethod
from collections.abc import Generator
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, Callable, Dict, Generator, List, Optional, Protocol, Set, Union
from typing import Any, Callable, Optional, Protocol, Union

import torch
import torch.distributed as dist

@ -435,7 +436,7 @@ def _templated_ring_attention(
raise RuntimeError("Load balancing requires `is_causal=True`.")

if isinstance(mesh, dist.ProcessGroup):
pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
pg: Union[dist.ProcessGroup, list[dist.ProcessGroup]] = mesh
else:
pg = mesh.get_group()
assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"

@ -452,7 +453,7 @@ def _templated_ring_attention(

sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim)

rest: List[Any]
rest: list[Any]
out: torch.Tensor
logsumexp: torch.Tensor

@ -519,7 +520,7 @@ def _templated_ring_attention(
def _sdpa_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
# extract local tensor and sharding infos to an OpInfo
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)

@ -557,7 +558,7 @@ def _sdpa_handler(
def _sdpa_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as output tensor
args = list(args)

@ -614,7 +615,7 @@ def _templated_ring_attention_backward(
size = dist.get_world_size(pg)
next_kv = None
next_grad_kv = None
rest: List[Any]
rest: list[Any]
grad_query_, grad_key_, grad_value_ = None, None, None

accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype

@ -850,7 +851,7 @@ customized_ops = {
}


_replaced_functions: Dict[Callable, tuple[str, Callable]] = {}
_replaced_functions: dict[Callable, tuple[str, Callable]] = {}


def _distribute_function(

@ -890,7 +891,7 @@ def _distribute_function(
def wrapper(
target_fn: Callable, input_fn: Optional[Callable], output_fn: Optional[Callable]
) -> Callable:
def inner_fn(*args: tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any:
def inner_fn(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
if input_fn is not None:
args, kwargs = input_fn(device_mesh, *args, **kwargs)
output = target_fn(*args, **kwargs)

@ -1045,8 +1046,8 @@ def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, N
"""Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher."""

def attention_input_fn(
mesh: DeviceMesh, *args: tuple[Any, ...], **kwargs: Dict[str, Any]
) -> tuple[tuple[Any, ...], Dict[str, Any]]:
mesh: DeviceMesh, *args: tuple[Any, ...], **kwargs: dict[str, Any]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
placement = [Shard(seq_dim)]
all_args = []

@ -1180,9 +1181,9 @@ class _RoundRobinLoadBalancer(_LoadBalancer):

def _context_parallel_buffers(
mesh: DeviceMesh,
buffers: List[torch.Tensor],
buffer_seq_dims: List[int],
) -> List[torch.Tensor]:
buffers: list[torch.Tensor],
buffer_seq_dims: list[int],
) -> list[torch.Tensor]:
"""Shard the buffers along the sequence dimensions according to CP rules."""
new_buffers = []
sharder = (

@ -1201,9 +1202,9 @@ def _context_parallel_buffers(
def context_parallel(
mesh: DeviceMesh,
*,
buffers: Optional[List[torch.Tensor]] = None,
buffer_seq_dims: Optional[List[int]] = None,
no_restore_buffers: Optional[Set[torch.Tensor]] = None,
buffers: Optional[list[torch.Tensor]] = None,
buffer_seq_dims: Optional[list[int]] = None,
no_restore_buffers: Optional[set[torch.Tensor]] = None,
) -> Generator[None, None, None]:
"""

@ -1267,9 +1268,9 @@ def context_parallel(
@torch.no_grad()
def context_parallel_unshard(
mesh: DeviceMesh,
buffers: List[torch.Tensor],
seq_dims: List[int],
) -> List[torch.Tensor]:
buffers: list[torch.Tensor],
seq_dims: list[int],
) -> list[torch.Tensor]:
"""
Unshard the tensors (e.g., output) that are sharded due to context parallelism.
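A minimal driving sketch (tensor shapes, device, and mesh are assumptions; context_parallel_unshard is imported from the private module whose diff appears above):

    # Illustrative only: shard q/k/v along the sequence dim for ring attention, then unshard the output.
    import torch
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard

    mesh = init_device_mesh("cuda", (2,))                 # assumes 2 GPU ranks
    q, k, v = (torch.randn(1, 8, 64, 16, device="cuda", requires_grad=True) for _ in range(3))

    with context_parallel(mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    (full_out,) = context_parallel_unshard(mesh, [out], [2])   # gather the sequence dim back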
@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Callable, Optional, Union

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor
@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from collections.abc import Sequence
from functools import partial
from typing import Callable, List, Sequence, Union
from typing import Callable, Union

import torch
from torch._ops import OpOverload

@ -21,7 +22,7 @@ from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy
__all__ = ["register_sharding"]


def register_sharding(op: Union[OpOverload, List[OpOverload]]):
def register_sharding(op: Union[OpOverload, list[OpOverload]]):
"""
:meth:`register_sharding` is an experimental API that allows users to register sharding
strategies for an operator when the tensor inputs and outputs are DTensor.

@ -89,7 +90,7 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]):

acceptable_shardings = custom_sharding_fn(*args_schema, **kwargs_schema)

single_mesh_dim_strategies: List[PlacementList] = []
single_mesh_dim_strategies: list[PlacementList] = []
for output_specs, input_specs in acceptable_shardings:
single_mesh_dim_strategies.append(output_specs + input_specs)

@ -110,7 +111,7 @@ def register_sharding(op: Union[OpOverload, List[OpOverload]]):
# 2. let static_kwargkey include all the int type kwargs
# 3. always set needs_pytree=True
static_argnum = 100
static_kwargkey: List[str] = []
static_kwargkey: list[str] = []
for i, arg in enumerate(op._schema.arguments):
if isinstance(arg.type, torch.IntType) or (
isinstance(arg.type, torch.OptionalType)
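A hedged sketch of the decorator in use (the op choice and strategy list follow the pattern the docstring describes; they are illustrative, not mandated by this diff):

    # Illustrative only: register replicate/shard strategies for aten._softmax.default.
    import torch
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.experimental import register_sharding

    aten = torch.ops.aten

    @register_sharding(aten._softmax.default)
    def custom_softmax_sharding(x, dim, half_to_float):
        softmax_dim = dim if dim >= 0 else dim + x.ndim
        acceptable_shardings = [([Replicate()], [Replicate(), None, None])]   # fully replicated
        for sharding_dim in range(x.ndim):
            if sharding_dim != softmax_dim:                                   # any non-softmax dim may stay sharded
                acceptable_shardings.append(
                    ([Shard(sharding_dim)], [Shard(sharding_dim), None, None])
                )
        return acceptable_shardings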
@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import copy
import operator
from typing import Any, cast, Dict, List, Optional, Sequence
from collections.abc import Sequence
from typing import Any, cast, Optional

import torch
from torch._subclasses.fake_tensor import FakeTensor

@ -36,7 +37,7 @@ def tensor_parallel_transformation(
rank: int,
world_size: int,
device_type: str,
parallel_strategies: Dict[str, ParallelStyle],
parallel_strategies: dict[str, ParallelStyle],
) -> ExportedProgram:
"""
The entry point function to perform graph transformations on an exported program

@ -77,14 +78,14 @@ class _TensorParallelTransformPass(PassBase):
rank: int,
world_size: int,
device_type: str,
state_dict: Dict[str, torch.Tensor],
state_dict: dict[str, torch.Tensor],
graph_signature: ExportGraphSignature,
parallel_strategies: Dict[str, ParallelStyle],
parallel_strategies: dict[str, ParallelStyle],
) -> None:
super().__init__()
self.rank = rank
self.mesh = DeviceMesh(device_type, torch.arange(world_size))
self.state_dict: Dict[str, torch.Tensor] = state_dict
self.state_dict: dict[str, torch.Tensor] = state_dict
self.graph_signature = graph_signature
self.parallel_strategies = parallel_strategies

@ -105,13 +106,13 @@ class _TensorParallelTransformPass(PassBase):


def _generate_parameter_and_buffer_placements(
params_and_buffers: List[str],
parallel_strategies: Dict[str, ParallelStyle],
) -> Dict[str, Placement]:
params_and_buffers: list[str],
parallel_strategies: dict[str, ParallelStyle],
) -> dict[str, Placement]:
"""
Build parameter placements based on the given parallel style of linear layers.
"""
parameter_placements: Dict[str, Placement] = {}
parameter_placements: dict[str, Placement] = {}
for linear_fqn, parallel_style in parallel_strategies.items():
weight_fqn = f"{linear_fqn}.weight"
bias_fqn = f"{linear_fqn}.bias"

@ -130,12 +131,12 @@ def _mark_tensor_parallel_shardings(
gm: GraphModule,
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
parameter_placements: Dict[str, Placement],
) -> Dict[Node, PlacementStrategy]:
parameter_placements: dict[str, Placement],
) -> dict[Node, PlacementStrategy]:
"""
Mark the placement strategies of the parameter and buffer placeholder nodes.
"""
placement_strategies: Dict[Node, PlacementStrategy] = {}
placement_strategies: dict[Node, PlacementStrategy] = {}
num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len(
graph_signature.inputs_to_buffers
)

@ -182,12 +183,12 @@ def _mark_sharding(
gm: GraphModule,
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
parameter_placements: Dict[str, Placement],
) -> Dict[Node, PlacementStrategy]:
parameter_placements: dict[str, Placement],
) -> dict[Node, PlacementStrategy]:
"""
Mark the sharding strategy for each node in the graph module.
"""
placement_strategies: Dict[
placement_strategies: dict[
Node, PlacementStrategy
] = _mark_tensor_parallel_shardings(gm, graph_signature, mesh, parameter_placements)

@ -485,12 +486,12 @@ def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None:


def _get_input_node_specs(
node: Node, placement_strategies: Dict[Node, PlacementStrategy]
node: Node, placement_strategies: dict[Node, PlacementStrategy]
) -> tuple[DTensorSpec, ...]:
"""
Get the input specs of a node.
"""
input_specs_list: List[DTensorSpec] = []
input_specs_list: list[DTensorSpec] = []
for input_arg in node.all_input_nodes:
if input_arg in placement_strategies:
output_spec = placement_strategies[input_arg].output_specs

@ -502,7 +503,7 @@ def _get_input_node_specs(


def _get_op_schema(
node: Node, placement_strategies: Dict[Node, PlacementStrategy]
node: Node, placement_strategies: dict[Node, PlacementStrategy]
) -> OpSchema:
"""
Util function to construct the operator schema of a node.

@ -513,14 +514,14 @@ def _get_op_schema(
op_schema = OpSchema(
op=cast(torch._ops.OpOverload, node.target),
args_schema=tuple(args_schema_list),
kwargs_schema=cast(Dict[str, object], node.kwargs),
kwargs_schema=cast(dict[str, object], node.kwargs),
)
return op_schema


def _shard_state_dict(
state_dict: Dict[str, torch.Tensor],
placement_strategies: Dict[Node, PlacementStrategy],
state_dict: dict[str, torch.Tensor],
placement_strategies: dict[Node, PlacementStrategy],
graph_signature: ExportGraphSignature,
mesh: DeviceMesh,
) -> None:
@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
from fnmatch import fnmatch
from typing import Dict, Optional, Union
from typing import Optional, Union

import torch
import torch.nn as nn

@ -16,7 +16,7 @@ __all__ = ["parallelize_module"]
def parallelize_module(  # type: ignore[return]
module: nn.Module,
device_mesh: Optional[DeviceMesh] = None,
parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None,
*,
src_data_rank: Optional[int] = 0,
) -> nn.Module:
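A brief sketch (toy model and plan keys are assumptions) of passing the now dict-typed plan:

    # Illustrative only: column/row-parallelize the two linear layers of a toy MLP.
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        ParallelStyle,
        RowwiseParallel,
        parallelize_module,
    )

    tp_mesh = init_device_mesh("cuda", (2,))              # assumes 2 GPU ranks
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

    plan: dict[str, ParallelStyle] = {"0": ColwiseParallel(), "2": RowwiseParallel()}
    model = parallelize_module(model, tp_mesh, plan)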
@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, List, Optional, Set
from typing import Any, Optional

import torch.nn as nn
from torch.distributed.tensor.parallel._data_parallel_utils import (

@ -23,7 +23,7 @@ def _get_submodule_n_params(module: nn.Module, path: str):
return module, path


def _update_module_param(param_list: List[tuple[nn.Module, str, nn.Parameter]]):
def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]):
"""
Update parameters within the module
"""

@ -48,7 +48,7 @@ def _reconstruct_dtensor(module: nn.Module, _input: Any):


def _localize_dtensor(
module: nn.Module, *_: Any, ignored_params: Optional[Set[nn.Parameter]] = None
module: nn.Module, *_: Any, ignored_params: Optional[set[nn.Parameter]] = None
):
"""
Convert DTensor parameters to local tensors
@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import copy
from typing import Any, cast, List, Optional
from typing import Any, cast, Optional

import torch
import torch.distributed as dist

@ -163,7 +163,7 @@ def _chunk_tensor(
)

outer_local_shard = tensor.local_shards()[0]
shards: List[Shard] = [
shards: list[Shard] = [
Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
]
st_meta = copy.deepcopy(tensor.metadata())

@ -284,7 +284,7 @@ def _chunk_dtensor(

def _pre_load_state_dict(
tensor: torch.Tensor,
) -> tuple[torch.Tensor, List[Shard]]:
) -> tuple[torch.Tensor, list[Shard]]:
shards = cast(ShardedTensor, tensor).local_shards()
if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
inner_tensor = shards[0].tensor

@ -377,7 +377,7 @@ class DTensorExtensions(FSDPExtensions):
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> tuple[torch.Tensor, List[Shard]]:
) -> tuple[torch.Tensor, list[Shard]]:
return _pre_load_state_dict(tensor)

def all_gather_dtensor(
@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import contextlib
from typing import cast, Dict, Optional
from typing import cast, Optional

import torch
import torch._prims_common as utils

@ -108,7 +108,7 @@ def _cast_to_dtensor(
def _propagate_tensor_meta(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> TensorMeta:
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(

@ -154,7 +154,7 @@ def _log_softmax(x, dim, half_to_float, mesh, mesh_dim):
def _log_softmax_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
x = cast(DTensor, args[0])
dim = cast(int, args[1])

@ -185,7 +185,7 @@ def _log_softmax_handler(
def _log_softmax_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
grad_output = cast(DTensor, args[0])
input_dtype = cast(torch.dtype, args[3])

@ -270,7 +270,7 @@ def _nll_loss_forward(
def _nll_loss_forward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
x = cast(DTensor, args[0])
target = args[1]

@ -414,7 +414,7 @@ def _nll_loss_and_log_softmax_backward(
def _nll_loss_backward_handler(
op_call: torch._ops.OpOverload,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> object:
grad_output = cast(DTensor, args[0])
x = cast(DTensor, args[1])
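For orientation, a hedged sketch of the public wrapper around these handlers (tensor setup and mesh are assumptions):

    # Illustrative only: compute a class-dim-sharded cross entropy under loss_parallel.
    import torch
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Shard
    from torch.distributed.tensor.parallel import loss_parallel

    mesh = init_device_mesh("cuda", (2,))                 # assumes 2 GPU ranks
    logits = distribute_tensor(torch.randn(8, 10, requires_grad=True), mesh, [Shard(1)])
    labels = torch.randint(0, 10, (8,), device="cuda")

    with loss_parallel():
        loss = F.cross_entropy(logits, labels)
        loss.backward()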
@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

import torch
import torch.nn as nn

@ -453,8 +453,8 @@ class PrepareModuleInput(ParallelStyle):
desired_input_layouts: Optional[
Union[Placement, tuple[Optional[Placement]]]
] = None,
input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
input_kwarg_layouts: Optional[dict[str, Placement]] = None,
desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None,
use_local_output: bool = False,
):
self.input_layouts = (
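A short sketch (layout choices and module name are illustrative) of the retyped kwarg-layout parameters in use:

    # Illustrative only: redistribute a module's positional and keyword inputs before it runs.
    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import PrepareModuleInput

    prepare_in = PrepareModuleInput(
        input_layouts=(Shard(0),),                        # how runtime inputs arrive
        desired_input_layouts=(Replicate(),),             # how the wrapped module wants them
        input_kwarg_layouts={"mask": Shard(0)},           # now typed dict[str, Placement]
        desired_input_kwarg_layouts={"mask": Replicate()},
    )
    # typically applied as one entry of a parallelize_module plan, e.g. {"attn": prepare_in}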
@ -2,7 +2,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

from dataclasses import dataclass
from typing import cast, List, Optional
from typing import cast, Optional

import torch
import torch.distributed._functional_collectives as funcol

@ -73,7 +73,7 @@ class Shard(Placement):
*,
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[List[torch.Tensor], List[int]]:
) -> tuple[list[torch.Tensor], list[int]]:
"""
This function uses torch.chunk to split a tensor into num_chunks shards along
the Shard placement dimension, and return a list of shards with their pad sizes.

@ -236,7 +236,7 @@ class Shard(Placement):
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
current_logical_shape: list[int],
) -> torch.Tensor:
"""
This function all_gather all shards and return a tensor that

@ -292,7 +292,7 @@ class Shard(Placement):
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
current_logical_shape: list[int],
new_shard_dim: int,
) -> torch.Tensor:
"""

@ -464,7 +464,7 @@ class _StridedShard(Shard):
*,
with_padding: bool = True,
contiguous: bool = True,
) -> tuple[List[torch.Tensor], List[int]]:
) -> tuple[list[torch.Tensor], list[int]]:
"""
TODO: currently _StridedShard does not support padding
"""

@ -502,7 +502,7 @@ class _StridedShard(Shard):
local_tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
current_logical_shape: List[int],
current_logical_shape: list[int],
) -> torch.Tensor:
"""
Note: currently _StridedShard does not support padding
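A tiny sketch (shapes are illustrative) of the torch.chunk semantics the Shard docstrings above refer to:

    # Illustrative only: Shard(0) splits with torch.chunk, so uneven splits leave a short
    # (or empty) last shard that the placement pads when padding is requested.
    import torch

    global_tensor = torch.arange(10).reshape(5, 2)
    num_chunks = 4                                        # e.g. a mesh dimension of size 4
    shards = list(torch.chunk(global_tensor, num_chunks, dim=0))
    print([s.shape[0] for s in shards])                   # [2, 2, 1]: the fourth shard would be empty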