[DeviceMesh] Move DeviceMesh out from torch.distributed._tensor (#112364)
Move DeviceMesh out as a standalone module. Once we make sure everything is migrated and the docs are ready, we will make `torch.distributed._device_mesh` public in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112364
Approved by: https://github.com/wanchaol, https://github.com/fegin, https://github.com/fduwjj
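For orientation, here is a minimal sketch (not part of this commit) of the new import path this PR introduces; the module stays private (underscore-prefixed) until the follow-up PRs land:

```
# Hypothetical usage sketch: DeviceMesh and init_device_mesh now live in
# their own standalone module. Assumes a multi-rank job (e.g. via torchrun).
from torch.distributed._device_mesh import DeviceMesh, init_device_mesh

# a 1-D mesh over all ranks (assumes an 8-rank world)
mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
```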
This commit is contained in:
parent 6f681ab5d9
commit b07cfd79fe

@@ -36,10 +36,12 @@ time python test/run_test.py --verbose -i distributed/test_functional_api

# DTensor tests
time python test/run_test.py --verbose -i distributed/_tensor/test_device_mesh
time python test/run_test.py --verbose -i distributed/_tensor/test_random_ops
time python test/run_test.py --verbose -i distributed/_tensor/test_dtensor_compile

# DeviceMesh test
time python test/run_test.py --verbose -i distributed/test_device_mesh

# DTensor/TP tests
time python test/run_test.py --verbose -i distributed/tensor/parallel/test_ddp_2d_parallel
time python test/run_test.py --verbose -i distributed/tensor/parallel/test_fsdp_2d_parallel

@@ -290,6 +290,7 @@ class TestPublicBindings(TestCase):
            "torch.backends._coreml.preprocess",
            "torch.contrib._tensorboard_vis",
            "torch.distributed._composable",
            "torch.distributed._device_mesh",
            "torch.distributed._functional_collectives",
            "torch.distributed._functional_collectives_impl",
            "torch.distributed._shard",
@@ -183,6 +183,7 @@ if torch.distributed.is_available():
    LEGACY_MOD_INLINELIST |= {
        "torch.distributed._tensor.api",
        "torch.distributed._tensor.device_mesh",
        "torch.distributed._device_mesh",
        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
        "torch.distributed.tensor.parallel._data_parallel_utils",
        "torch.distributed.tensor.parallel._utils",
torch/distributed/_device_mesh.py (new file)

@@ -0,0 +1,454 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed._functional_collectives as funcol

from torch.distributed.distributed_c10d import (
    _find_pg_by_ranks_and_tag,
    _get_default_group,
    _get_group_tag,
    get_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    new_group,
    ProcessGroup,
)


logger = logging.getLogger(__name__)

# only import numpy typing when type checking
if TYPE_CHECKING:
    try:
        from numpy.typing import ArrayLike
    except ImportError:
        logger.warning(
            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
        )

class _MeshEnv:
    def __init__(self) -> None:
        self.mesh_stack: List[DeviceMesh] = []
        self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}

    def get_current_mesh(self) -> "DeviceMesh":
        if len(self.mesh_stack) == 0:
            raise RuntimeError("No device mesh is currently active!")
        return self.mesh_stack[-1]

    def create_child_mesh(
        self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
    ) -> "DeviceMesh":
        # swap the current dim to the last dim then reshape to flatten out other
        # dims, so we can just extract the list of ranks which contains cur_rank.
        cur_rank = device_mesh.get_rank()
        pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
            -1, device_mesh.mesh.size(mesh_dim)
        )

        for mesh_1d in pg_ranks_by_dim:
            sub_mesh = DeviceMesh(
                device_mesh.device_type,
                mesh_1d,
                mesh_dim_names=(mesh_dim_name,),
                _init_process_groups=False,
            )
            if cur_rank in mesh_1d:
                res_sub_mesh = sub_mesh

        res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
        # Assign the current DeviceMesh as the parent of the child DeviceMesh.
        self.child_to_parent_mapping[res_sub_mesh] = device_mesh
        return res_sub_mesh

    def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
        return self.child_to_parent_mapping.get(device_mesh, None)

    def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
        """
        Return the index of the mesh dim in the parent mesh.
        The device_mesh passed in needs to be sliced out from a parent mesh.
        """
        parent_mesh = self.get_parent_mesh(device_mesh)
        child_mesh_dim_names = device_mesh.mesh_dim_names
        if parent_mesh and child_mesh_dim_names:
            assert (
                len(child_mesh_dim_names) == 1
            ), "The child mesh can only be a 1D mesh."
            child_mesh_dim_name = child_mesh_dim_names[0]
            if parent_mesh.mesh_dim_names:
                return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
        return None

    @staticmethod
    def num_devices_per_host(device_type: str) -> int:
        return _get_device_handle(device_type).device_count()

    @staticmethod
    def num_hosts(device_type: str) -> int:
        # ProcessGroup can't tell us this info so we have to infer it, assume
        # homogeneous hardware for now
        return get_world_size() // _MeshEnv.num_devices_per_host(device_type)


_mesh_resources: _MeshEnv = _MeshEnv()


def _get_device_handle(device_type: str = "cuda"):
    """
    Get the module corresponding to the device_type, which is a cuda or cuda-like device.
    For example, when the device_type is cuda, the module `torch.cuda` is returned.
    Return None when there is no corresponding module for device_type; otherwise
    return the corresponding module.
    """
    return getattr(torch, device_type, None)


class DeviceMesh:
    """
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-d array, and each value of the n-d array is the global
    rank id in the default process group.

    DeviceMesh can be used to describe the layout of devices across the cluster,
    and serves as a proxy for communication among the device lists within the cluster.

    We use the default ProcessGroup in this DeviceMesh class to implement proper
    communications. Note that we also add collective wrappers in this class. This is
    used to decouple the detailed communication backend from the underlying
    DTensor implementation.

    DeviceMesh can be used as a context manager.

    Args:
        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
        mesh (ndarray): a multi-dimensional array or an integer tensor that
            describes the layout of devices, where the ids are global ids of the
            default process group.

    Returns:
        A :class:`DeviceMesh` object

    Example (2 hosts with 4 GPUs each):
    ```
    # The following program runs on each process/rank in SPMD manner.
    # initialize device mesh as (2, 4) to represent the topology
    # of cross-host (dim 0) and within-host (dim 1)
    mesh = DeviceMesh(device_type="cuda",
                      mesh=[
                          [0, 1, 2, 3],
                          [4, 5, 6, 7]
                      ])
    ```
    A reduction over the first dimension of mesh will reduce across
    columns (0, 4), ..., and (3, 7); a reduction over the second dimension
    of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
    """

    device_type: str
    mesh: torch.Tensor
    mesh_dim_names: Optional[Tuple[str, ...]]

    def __init__(
        self,
        device_type: str,
        mesh: Union[torch.Tensor, "ArrayLike"],
        *,
        mesh_dim_names: Optional[Tuple[str, ...]] = None,
        _init_process_groups: bool = True,
        _validate_mesh: bool = True,
    ) -> None:
        self.device_type = device_type
        self.mesh = (
            mesh.detach()
            if isinstance(mesh, torch.Tensor)
            else torch.tensor(mesh, dtype=torch.int)
        )
        self.mesh_dim_names = mesh_dim_names

        # private field to pre-generate DeviceMesh's hash
        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
        self._hash = hash((self._flatten_mesh_list, self.mesh.shape))

        # Skip process group initialization if xla device.
        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
        if device_type != "xla":
            # always try to create default (world) pg, even if it is not initialized
            # already. The world pg is used for device mesh identity (rank) on each
            # process (we need to know if the current global rank is in the mesh or not).
            self._get_or_create_default_group()
            if _init_process_groups:
                self._init_process_groups(_validate_mesh)

    def _get_or_create_default_group(self):
        default_initialized = is_initialized()
        if not default_initialized:
            init_process_group()

        world_size = get_world_size()
        if self.mesh.numel() > world_size:
            raise RuntimeError(
                f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
            )

        device_handle = _get_device_handle(self.device_type)
        # TODO: if the user wants to pass pg_options, offer a way to do it
        if not default_initialized and device_handle:
            # automatically set the current cuda/cuda-like device based on the number of devices available per host
            # NOTE: This device selection would only work for homogeneous hardware.
            num_devices_per_host = device_handle.device_count()
            if (
                world_size > num_devices_per_host
                and world_size % num_devices_per_host != 0
            ):
                raise RuntimeError(
                    f"DeviceMesh only supports homogeneous hardware, but found "
                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
                )
            device_handle.set_device(get_rank() % num_devices_per_host)

        # calculate the coordinates of the current global rank on the mesh
        rank_coords = (self.mesh == get_rank()).nonzero()
        assert rank_coords.size(0) in (0, 1)
        self._coordinate_on_dim: Optional[List[int]] = (
            rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
        )
        return _get_default_group()

    def _validate_mesh(self):
        # check mesh tensor validity
        unique_mesh_values = self.mesh.unique(sorted=True)
        if unique_mesh_values.numel() != self.mesh.numel():
            raise RuntimeError(
                f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
            )

        # validate that all calling ranks pass in the same `mesh` argument.
        self_mesh = self.mesh.to(self.device_type).contiguous()
        mesh_tensor = funcol.all_gather_tensor(
            self_mesh, gather_dim=0, group=_get_default_group()
        )
        mesh_tensor_chunked = torch.chunk(mesh_tensor, get_world_size())
        for other_rank, other_mesh in enumerate(mesh_tensor_chunked):
            if not torch.equal(self_mesh, other_mesh):
                raise RuntimeError(
                    f"DeviceMesh initialization does not allow different mesh arguments: "
                    f"rank {get_rank()} has mesh {self_mesh} while rank {other_rank} "
                    f"has mesh {other_mesh}!"
                )

    def _init_process_groups(self, _validate_mesh):
        if _validate_mesh:
            self._validate_mesh()

        # group tag/ranks associated with each mesh dimension; each mesh dimension should
        # have one sub-group per rank
        dim_group_infos: List[Tuple[str, List[int]]] = []

        if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
            # if the mesh is the same as world_pg, we just append the default
            # pg to the first dim groups, as new_group cannot have the exact
            # same ranks as world
            dim_group_infos.append(
                (_get_group_tag(_get_default_group()), list(range(get_world_size())))
            )
        else:
            # create sub pgs based on the mesh argument specified
            for dim in range(self.mesh.ndim):
                # swap the current dim to the last dim
                # then reshape to flatten out other dims
                pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                    -1, self.mesh.size(dim)
                )
                # multi-dim mesh, create subgroups by looping over the pg_ranks
                # for each dim and append the groups
                for dim_mesh in pg_ranks_by_dim:
                    subgroup_ranks = dim_mesh.tolist()
                    # call new_group regardless of whether the current rank is in the
                    # pg or not; it's required that all ranks participate
                    # in subgroup construction
                    dim_group = new_group(ranks=subgroup_ranks)
                    # only add to dim_groups if the current rank is in the subgroup
                    if self.get_rank() in subgroup_ranks:
                        if len(dim_group_infos) > dim:
                            raise RuntimeError(
                                f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                f"in {subgroup_ranks}!"
                            )
                        dim_group_infos.append(
                            (_get_group_tag(dim_group), subgroup_ranks)
                        )
        self._dim_group_infos = dim_group_infos

    def __enter__(self) -> "DeviceMesh":
        # set this mesh as the current mesh in mesh env
        _mesh_resources.mesh_stack.append(self)
        return self

    # pyre-fixme[2]: Parameter must be annotated.
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # pop this mesh from mesh env
        _mesh_resources.mesh_stack.pop()

    def __repr__(self) -> str:
        return f"DeviceMesh:({self.mesh.tolist()})"

    def __hash__(self):
        return self._hash

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DeviceMesh):
            return False
        if id(self.mesh) == id(other.mesh):
            return True
        return (
            self.mesh.shape == other.mesh.shape
            and self._flatten_mesh_list == other._flatten_mesh_list
        )

    def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
        """
        Slice the current DeviceMesh based on the given mesh_dim_name to create a child
        DeviceMesh.

        Args:
            mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
                to create a child DeviceMesh for.
        Returns:
            A :class:`DeviceMesh` object

        Example (2 hosts with 4 GPUs each):
        ```
        # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_names of ("dp", "tp")
        mesh = DeviceMesh(device_type="cuda",
                          mesh=[
                              [0, 1, 2, 3],
                              [4, 5, 6, 7]
                          ],
                          mesh_dim_names=("dp", "tp"))
        ```
        Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
        Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
        Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
        Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
        Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
        Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
        """
        if self.mesh.ndim <= 1:
            raise RuntimeError(
                f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension(s)."
            )
        if self.mesh_dim_names is None:
            raise KeyError(
                "No `mesh_dim_names` found.",
                "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
            )
        if mesh_dim_name not in self.mesh_dim_names:
            raise KeyError(
                f"Mesh dimension '{mesh_dim_name}' does not exist.",
                f"Available mesh dimensions are: {self.mesh_dim_names}",
            )
        mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
        submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)

        return submesh

    def get_dim_groups(
        self, mesh_dim: Optional[int] = None
    ) -> Union[ProcessGroup, List[ProcessGroup]]:
        if not hasattr(self, "_dim_group_infos"):
            raise RuntimeError("DeviceMesh process groups not initialized!")
        if mesh_dim is not None:
            return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
        else:
            dim_groups = []
            for mesh_dim in range(self.mesh.ndim):
                dim_groups.append(
                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
                )
            return dim_groups

    def size(self, dim: Optional[int] = None) -> int:
        return self.mesh.numel() if dim is None else self.mesh.size(dim)

    @property
    def ndim(self) -> int:
        return self.mesh.ndim

    @property
    def shape(self) -> Tuple[int, ...]:
        return tuple(self.mesh.shape)

    def get_rank(self) -> int:
        return get_rank()

    def get_coordinate(self) -> Optional[List[int]]:
        """
        Return the relative indices of this rank relative to all
        dimensions of the mesh. If this rank is not part of the mesh, return None.
        """
        return self._coordinate_on_dim if self._coordinate_on_dim else None


def init_device_mesh(
    device_type: str,
    mesh_shape: Tuple[int, ...],
    *,
    mesh_dim_names: Optional[Tuple[str, ...]] = None,
) -> DeviceMesh:
    """
    Initializes a `DeviceMesh` based on the `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.

    This creates a DeviceMesh with an n-d array layout, where n is len(mesh_shape)
    and the ith dimension has size mesh_shape[i]. If mesh_dim_names is provided,
    each dimension is labeled mesh_dim_names[i].

    Args:
        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
        mesh_shape (Tuple[int]): A tuple that describes the dimensions of the multi-dimensional
            array laying out the devices.
    Kwargs:
        mesh_dim_names (Optional[Tuple[str]]): A tuple of mesh dim names to be assigned to each
            dimension of the multi-dimensional array laying out the devices. Its length must
            match the length of `mesh_shape`. Each string in mesh_dim_names must be unique.

    Returns:
        A :class:`DeviceMesh` object

    .. note: If no process group is found, init_device_mesh will initialize distributed
        process group/groups behind the scenes, which are required for distributed communications.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torch.distributed._device_mesh import init_device_mesh
        >>>
        >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
        >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
    """
    if mesh_dim_names is not None:
        if len(set(mesh_dim_names)) != len(mesh_dim_names):
            raise RuntimeError(
                "Each mesh_dim_name must be unique.",
                f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
            )

        if len(mesh_shape) != len(mesh_dim_names):
            raise RuntimeError(
                "mesh_shape and mesh_dim_names should have the same length!",
                f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape): {len(mesh_shape)}.",
            )

    mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
    device_mesh = DeviceMesh(
        device_type=device_type,
        mesh=mesh,
        mesh_dim_names=mesh_dim_names,
    )

    return device_mesh
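To make the moved API concrete, here is a small usage sketch (not part of the diff). It assumes an 8-rank job launched with torchrun on a host with 8 CUDA devices; everything it calls is defined in the file above:

```
# Illustrative sketch only, not part of this commit.
# Assumes: `torchrun --nproc_per_node=8 demo.py` with 8 CUDA devices visible.
import torch.distributed as dist
from torch.distributed._device_mesh import init_device_mesh

# A 2 x 4 mesh with named dims; init_device_mesh creates the default process
# group (and the per-dimension subgroups) if none exists yet.
mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))

tp_mesh = mesh_2d["tp"]               # 1-D child mesh sliced out by dim name
dp_group = mesh_2d.get_dim_groups(0)  # ProcessGroup backing the "dp" dimension

print(mesh_2d.get_coordinate())       # e.g. [0, 3] on global rank 3
print(tp_mesh.size())                 # 4

dist.destroy_process_group()
```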
torch/distributed/_tensor/device_mesh.py

@@ -1,454 +1,6 @@

(The 454 removed lines are a verbatim copy of the new torch/distributed/_device_mesh.py shown above; the module is reduced to a re-export shim.)

# Copyright (c) Meta Platforms, Inc. and affiliates
from torch.distributed._device_mesh import (  # noqa: F401
    _get_device_handle,
    _mesh_resources,
    DeviceMesh,
    init_device_mesh,
)
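Because of this re-export shim, existing imports through the old path keep working during the migration; a minimal sketch of what that means in practice:

```
# Illustrative sketch, not part of this commit: the old import path resolves
# to the same objects as the new one, thanks to the re-export shim above.
from torch.distributed._device_mesh import DeviceMesh as NewDeviceMesh
from torch.distributed._tensor.device_mesh import DeviceMesh as OldDeviceMesh

assert NewDeviceMesh is OldDeviceMesh  # same class object, merely re-exported
```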