Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Summary: Rename _device_mesh.py to device_mesh.py, update all call sites, and add documentation. We created stubs for the public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported whether or not torch.distributed.is_available(). Original diff reverted: D51629761. Original PR reverted: https://github.com/pytorch/pytorch/pull/115099. Prior to landing, all CI signals passed; Shipit added the "ci/trunk" label to the PR but did not wait for it and went ahead with committing. More context can be found in the reverted PR above.

Test Plan: CI.

Differential Revision: D51861018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
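For downstream code, the practical effect of this change is the import-path update applied at every call site in the diff below; a minimal before/after sketch (the mesh shape and dimension names here are illustrative, not part of the commit):

```
# Old, private import paths (removed or redirected by this commit):
#   from torch.distributed._device_mesh import DeviceMesh, init_device_mesh
#   from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh

# New, public import path:
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# The import itself now succeeds even when distributed is unavailable (the
# module installs stubs in that case; see torch/distributed/device_mesh.py
# further down). Actually creating a mesh still requires a distributed
# environment, e.g. a script launched with torchrun:
mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))
```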
This commit is contained in:
parent 6c585de076
commit 23fa9621e4
@@ -177,13 +177,15 @@ Initialization
--------------

The package needs to be initialized using the :func:`torch.distributed.init_process_group`
function before calling any other methods. This blocks until all processes have
joined.
or :func:`torch.distributed.device_mesh.init_device_mesh` function before calling any other methods.
Both block until all processes have joined.
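A quick sketch of the two entry points just mentioned (assuming the script is launched with ``torchrun`` so the usual rendezvous environment variables are set):

```
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Option 1: initialize the default process group explicitly.
dist.init_process_group(backend="nccl")

# Option 2: create a DeviceMesh; if no default process group exists yet,
# init_device_mesh initializes one behind the scenes.
mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
```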

.. autofunction:: is_available

.. autofunction:: init_process_group

.. autofunction:: torch.distributed.device_mesh.init_device_mesh

.. autofunction:: is_initialized

.. autofunction:: is_mpi_available
@@ -339,6 +341,18 @@ an opaque group handle that can be given as a ``group`` argument to all collecti

.. autofunction:: get_process_group_ranks


DeviceMesh
----------

DeviceMesh is a higher-level abstraction that manages process groups (or NCCL communicators).
It allows users to easily create inter-node and intra-node process groups without worrying about
how to set up ranks correctly for different sub-process groups, and it helps manage those
distributed process groups easily. The :func:`~torch.distributed.device_mesh.init_device_mesh` function
can be used to create a new DeviceMesh, with a mesh shape describing the device topology.
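A short usage sketch for the 2-hosts-by-4-GPUs layout used in the examples further down (the ``"dp"``/``"tp"`` dimension names are illustrative):

```
from torch.distributed.device_mesh import init_device_mesh

# Runs on every rank (SPMD); creates a 2 x 4 mesh over 8 global ranks.
mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))

# Child meshes (and their process groups) are sliced out by dimension name.
tp_mesh = mesh_2d["tp"]  # intra-host group, e.g. ranks [0, 1, 2, 3] on the first host
dp_mesh = mesh_2d["dp"]  # inter-host group, e.g. ranks [0, 4] for local rank 0
```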

.. autoclass:: torch.distributed.device_mesh.DeviceMesh

Point-to-point communication
----------------------------
@@ -849,6 +863,7 @@ If you are running single node training, it may be convenient to interactively b
.. py:module:: torch.distributed.checkpoint.utils
.. py:module:: torch.distributed.collective_utils
.. py:module:: torch.distributed.constants
.. py:module:: torch.distributed.device_mesh
.. py:module:: torch.distributed.distributed_c10d
.. py:module:: torch.distributed.elastic.agent.server.api
.. py:module:: torch.distributed.elastic.agent.server.local_elastic_agent
@@ -8,8 +8,8 @@ from torch.distributed._tensor._utils import (
    compute_local_shape,
    compute_local_shape_and_global_offset,
)
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -8,7 +8,7 @@ import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.distributed._tensor import DTensor, Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
    ShardedOptimStateDictConfig,
@@ -9,7 +9,7 @@ import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import (
    ShardedOptimStateDictConfig,
@@ -9,12 +9,8 @@ from torch.distributed._tensor._collective_utils import (
    mesh_broadcast,
    mesh_scatter,
)
from torch.distributed._tensor.device_mesh import (
    _mesh_resources,
    DeviceMesh,
    init_device_mesh,
)
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh

from torch.distributed.distributed_c10d import (
    get_global_rank,
@@ -292,7 +292,6 @@ class TestPublicBindings(TestCase):
    "torch.backends._coreml.preprocess",
    "torch.contrib._tensorboard_vis",
    "torch.distributed._composable",
    "torch.distributed._device_mesh",
    "torch.distributed._functional_collectives",
    "torch.distributed._functional_collectives_impl",
    "torch.distributed._shard",
@@ -183,7 +183,7 @@ if torch.distributed.is_available():
    LEGACY_MOD_INLINELIST |= {
        "torch.distributed._tensor.api",
        "torch.distributed._tensor.device_mesh",
        "torch.distributed._device_mesh",
        "torch.distributed.device_mesh",
        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
        "torch.distributed.tensor.parallel._data_parallel_utils",
        "torch.distributed.tensor.parallel._utils",
@@ -140,7 +140,7 @@ class DeviceMeshVariable(DistributedVariable):
        if not DistributedVariable.is_available():
            return False

        from torch.distributed._tensor.device_mesh import DeviceMesh
        from torch.distributed.device_mesh import DeviceMesh

        return istype(value, DeviceMesh)
@@ -1,505 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_find_pg_by_ranks_and_tag,
|
||||
_get_default_group,
|
||||
_get_group_tag,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
init_process_group,
|
||||
is_initialized,
|
||||
new_group,
|
||||
ProcessGroup,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# only import numpy typing when type checking
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from numpy.typing import ArrayLike
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"DeviceMesh requires numpy >= 1.21 to be installed for type checking"
|
||||
)
|
||||
|
||||
|
||||
class _MeshEnv:
|
||||
def __init__(self) -> None:
|
||||
self.mesh_stack: List[DeviceMesh] = []
|
||||
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
|
||||
|
||||
def get_current_mesh(self) -> "DeviceMesh":
|
||||
if len(self.mesh_stack) == 0:
|
||||
raise RuntimeError("No device mesh is currently active!")
|
||||
return self.mesh_stack[-1]
|
||||
|
||||
def create_child_mesh(
|
||||
self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
|
||||
) -> "DeviceMesh":
|
||||
# swap the current dim to the last dim then reshape to flatten out other
|
||||
# dims, so we can just extract the list of ranks which contains cur_rank.
|
||||
cur_rank = device_mesh.get_rank()
|
||||
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
|
||||
-1, device_mesh.mesh.size(mesh_dim)
|
||||
)
|
||||
|
||||
for mesh_1d in pg_ranks_by_dim:
|
||||
sub_mesh = DeviceMesh(
|
||||
device_mesh.device_type,
|
||||
mesh_1d,
|
||||
mesh_dim_names=(mesh_dim_name,),
|
||||
_init_process_groups=False,
|
||||
)
|
||||
if cur_rank in mesh_1d:
|
||||
res_sub_mesh = sub_mesh
|
||||
|
||||
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
|
||||
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
|
||||
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
|
||||
return res_sub_mesh
|
||||
|
||||
def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
|
||||
return self.child_to_parent_mapping.get(device_mesh, None)
|
||||
|
||||
def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
|
||||
"""
|
||||
Return the index of the mesh dim in the parent mesh.
|
||||
The device_mesh passed in needs to be sliced out from a parent mesh.
|
||||
"""
|
||||
parent_mesh = self.get_parent_mesh(device_mesh)
|
||||
child_mesh_dim_names = device_mesh.mesh_dim_names
|
||||
if parent_mesh and child_mesh_dim_names:
|
||||
assert (
|
||||
len(child_mesh_dim_names) == 1
|
||||
), "The child mesh can only be a 1D mesh."
|
||||
child_mesh_dim_name = child_mesh_dim_names[0]
|
||||
if parent_mesh.mesh_dim_names:
|
||||
return parent_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def num_devices_per_host(device_type: str) -> int:
|
||||
return _get_device_handle(device_type).device_count()
|
||||
|
||||
@staticmethod
|
||||
def num_hosts(device_type: str) -> int:
|
||||
# ProcessGroup can't tell us this info so we have to infer it, assume
|
||||
# homogeneous hardware for now
|
||||
return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
|
||||
|
||||
|
||||
_mesh_resources: _MeshEnv = _MeshEnv()
|
||||
|
||||
|
||||
def _get_device_handle(device_type: str = "cuda"):
|
||||
"""
|
||||
Get the module corresponding to the device_type which is cuda or cuda-like device.
|
||||
For example, when the device_type is cuda, the module `torch.cuda` is returned.
|
||||
Return None when there is no corresponding module for device_type, otherwise
|
||||
return the corresponding module.
|
||||
"""
|
||||
return getattr(torch, device_type, None)
|
||||
|
||||
|
||||
class DeviceMesh:
|
||||
"""
|
||||
DeviceMesh represents a mesh of devices, where layout of devices could be
|
||||
represented as a n-d dimension array, and each value of the n-d dimensional
|
||||
array is the global id of the default process group ranks.
|
||||
|
||||
DeviceMesh could be used to describe the layout of devices across the cluster,
|
||||
and serves as a proxy for communication among the device lists within the cluster.
|
||||
|
||||
We use the default ProcessGroup in this DeviceMesh class to implement proper
|
||||
communications. Note that we also add collective wrappers in this class. This is
|
||||
used to decouple detailed communication backend with the underlying
|
||||
DTensor implementation.
|
||||
|
||||
DeviceMesh can be used as a context manager.
|
||||
|
||||
.. note::
|
||||
DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
|
||||
is running on all processes/ranks in the cluster. Therefore, users need to make sure the
|
||||
`mesh` array (which describes the layout of devices) should be identical across all ranks.
|
||||
Inconsistent `mesh` will lead to silent hang.
|
||||
|
||||
Args:
|
||||
device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
|
||||
mesh (ndarray): could be a multi-dimension array or an integer tensor that
|
||||
describes the layout of devices, the ids are global ids of the
|
||||
default process group.
|
||||
|
||||
Returns:
|
||||
A :class:`DeviceMesh` object
|
||||
|
||||
Example (2 host with 4 GPUs each):
|
||||
```
|
||||
# The following program runs on each process/rank in SPMD manner.
|
||||
# initialize device mesh as (2, 4) to represent the topology
|
||||
# of cross-host(dim 0), and within-host (dim 1)
|
||||
mesh = DeviceMesh(device_type="cuda",
|
||||
mesh=[
|
||||
[0, 1, 2, 3],
|
||||
[4, 5, 6, 7]
|
||||
])
|
||||
```
|
||||
A reduction over the first dimension of mesh will reduce across
|
||||
columns (0, 4), .. and (3, 7), a reduction over the second dimension
|
||||
of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7)
|
||||
|
||||
"""
|
||||
|
||||
device_type: str
|
||||
mesh: torch.Tensor
|
||||
mesh_dim_names: Optional[Tuple[str, ...]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_type: str,
|
||||
mesh: Union[torch.Tensor, "ArrayLike"],
|
||||
*,
|
||||
mesh_dim_names: Optional[Tuple[str, ...]] = None,
|
||||
_init_process_groups: bool = True,
|
||||
) -> None:
|
||||
self.device_type = device_type
|
||||
self.mesh = (
|
||||
mesh.detach()
|
||||
if isinstance(mesh, torch.Tensor)
|
||||
else torch.tensor(mesh, dtype=torch.int)
|
||||
)
|
||||
self.mesh_dim_names = mesh_dim_names
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
|
||||
|
||||
# Skip process group initialization if xla device.
|
||||
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
|
||||
if device_type != "xla":
|
||||
# always try to create default (world) pg, even if it is not initialized
|
||||
# already. The world pg is used for device mesh identity (rank) on each
|
||||
# process (we need to know if the current global rank is in the mesh or not).
|
||||
self._get_or_create_default_group()
|
||||
if _init_process_groups:
|
||||
self._init_process_groups()
|
||||
|
||||
def _get_or_create_default_group(self):
|
||||
default_initialized = is_initialized()
|
||||
if not default_initialized:
|
||||
init_process_group()
|
||||
|
||||
world_size = get_world_size()
|
||||
if self.mesh.numel() > world_size:
|
||||
raise RuntimeError(
|
||||
f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
|
||||
)
|
||||
|
||||
device_handle = _get_device_handle(self.device_type)
|
||||
# TODO: if user want to pass pg_options, offer a way to do it
|
||||
if not default_initialized and device_handle:
|
||||
# automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
|
||||
# NOTE: This device selection would only work for homogeneous hardware.
|
||||
num_devices_per_host = device_handle.device_count()
|
||||
if (
|
||||
world_size > num_devices_per_host
|
||||
and world_size % num_devices_per_host != 0
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"DeviceMesh only support homogeneous hardware, but found "
|
||||
f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
|
||||
)
|
||||
device_handle.set_device(get_rank() % num_devices_per_host)
|
||||
|
||||
# calculate the coordinates of the current global rank on the mesh
|
||||
rank_coords = (self.mesh == get_rank()).nonzero()
|
||||
assert rank_coords.size(0) in (0, 1)
|
||||
self._coordinate_on_dim: Optional[List[int]] = (
|
||||
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
|
||||
)
|
||||
return _get_default_group()
|
||||
|
||||
def _init_process_groups(self):
|
||||
# group tag/ranks associated with each mesh dimension, each mesh dimension should
|
||||
# have one sub-group per rank
|
||||
dim_group_infos: List[Tuple[str, List[int]]] = []
|
||||
|
||||
if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
|
||||
# if the mesh is the same as world_pg, we just append the default
|
||||
# pg to the first dim groups, as new_group cannot have the exact
|
||||
# same ranks as world
|
||||
dim_group_infos.append(
|
||||
(_get_group_tag(_get_default_group()), list(range(get_world_size())))
|
||||
)
|
||||
else:
|
||||
# create sub pgs base on the mesh argument specified
|
||||
for dim in range(self.mesh.ndim):
|
||||
# swap the current dim to the last dim
|
||||
# then reshape to flatten out other dims
|
||||
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
|
||||
-1, self.mesh.size(dim)
|
||||
)
|
||||
# multi-dim mesh, create subgroups by looping over the pg_ranks
|
||||
# for each dim and append the groups
|
||||
for dim_mesh in pg_ranks_by_dim:
|
||||
subgroup_ranks = dim_mesh.tolist()
|
||||
# call new_group regardless of the current rank in the
|
||||
# pg or not, it's required that all ranks participate
|
||||
# in subgroup construction
|
||||
dim_group = new_group(ranks=subgroup_ranks)
|
||||
# only add to dim_groups if the current rank in the subgroup
|
||||
if self.get_rank() in subgroup_ranks:
|
||||
if len(dim_group_infos) > dim:
|
||||
raise RuntimeError(
|
||||
f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
|
||||
f"in {subgroup_ranks}!"
|
||||
)
|
||||
dim_group_infos.append(
|
||||
(_get_group_tag(dim_group), subgroup_ranks)
|
||||
)
|
||||
self._dim_group_infos = dim_group_infos
|
||||
|
||||
def __enter__(self) -> "DeviceMesh":
|
||||
# set this mesh as the current mesh in mesh env
|
||||
_mesh_resources.mesh_stack.append(self)
|
||||
return self
|
||||
|
||||
# pyre-fixme[2]: Parameter must be annotated.
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
|
||||
# pop this mesh from mesh env
|
||||
_mesh_resources.mesh_stack.pop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DeviceMesh({self.mesh.tolist()})"
|
||||
|
||||
def __hash__(self):
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, DeviceMesh):
|
||||
return False
|
||||
if id(self.mesh) == id(other.mesh):
|
||||
return True
|
||||
return (
|
||||
self.mesh.shape == other.mesh.shape
|
||||
and self._flatten_mesh_list == other._flatten_mesh_list
|
||||
)
|
||||
|
||||
def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
|
||||
"""
|
||||
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
|
||||
DeviceMesh.
|
||||
|
||||
Args:
|
||||
mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
|
||||
to create a child DeviceMesh for.
|
||||
Returns:
|
||||
A :class:`DeviceMesh` object
|
||||
|
||||
Example (2 host with 4 GPUs each):
|
||||
```
|
||||
# Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp")
|
||||
mesh = DeviceMesh(device_type="cuda",
|
||||
mesh=[
|
||||
[0, 1, 2, 3],
|
||||
[4, 5, 6, 7]
|
||||
],
|
||||
mesh_dim_names=["dp", "tp"])
|
||||
)
|
||||
```
|
||||
Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
|
||||
Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
|
||||
Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
|
||||
Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
|
||||
Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
|
||||
Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
|
||||
"""
|
||||
if self.mesh.ndim <= 1:
|
||||
raise RuntimeError(
|
||||
f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
|
||||
)
|
||||
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
|
||||
submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
|
||||
|
||||
return submesh
|
||||
|
||||
def get_group(
|
||||
self, mesh_dim: Optional[Union[int, str]] = None
|
||||
) -> Union[ProcessGroup, List[ProcessGroup]]:
|
||||
"""
|
||||
Returns a list of ProcessGroups corresponding to the mesh dimensions, or
|
||||
returns a single ProcessGroup if mesh_dim is specified or the given mesh has
|
||||
only one mesh dimension.
|
||||
|
||||
Optional Args:
|
||||
mesh_dim (str/int): it can be the name of the mesh dimension or the index
|
||||
of the mesh dimension. Default is None.
|
||||
Returns:
|
||||
A list of :class:`ProcessGroup` object when `mesh_dim` is not specified for
|
||||
a DeviceMesh with more than 1 dimension; otherwise, returns a single
|
||||
:class:`ProcessGroup` object.
|
||||
"""
|
||||
if not hasattr(self, "_dim_group_infos"):
|
||||
raise RuntimeError("DeviceMesh process groups not initialized!")
|
||||
|
||||
if self.mesh.ndim == 1:
|
||||
return _find_pg_by_ranks_and_tag(*self._dim_group_infos[0])
|
||||
|
||||
if mesh_dim is not None:
|
||||
if isinstance(mesh_dim, str):
|
||||
mesh_dim = self._get_mesh_dim_by_name(mesh_dim)
|
||||
return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
|
||||
else:
|
||||
dim_groups = []
|
||||
for ith_dim in range(self.mesh.ndim):
|
||||
dim_groups.append(
|
||||
_find_pg_by_ranks_and_tag(*self._dim_group_infos[ith_dim])
|
||||
)
|
||||
return dim_groups
|
||||
|
||||
def size(self, mesh_dim: Optional[int] = None) -> int:
|
||||
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
return self.mesh.ndim
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
return tuple(self.mesh.shape)
|
||||
|
||||
def get_rank(self) -> int:
|
||||
"""
|
||||
Returns the current global rank.
|
||||
"""
|
||||
return get_rank()
|
||||
|
||||
def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
|
||||
"""
|
||||
Returns the local rank of the given mesh_dim of the DeviceMesh.
|
||||
|
||||
Optional Args:
|
||||
mesh_dim (str/int): it can be the name of the mesh dimension or the index
|
||||
of the mesh dimension. Default is None.
|
||||
|
||||
Returns:
|
||||
An integer denotes the local rank.
|
||||
|
||||
Example (2 host with 4 GPUs each):
|
||||
```
|
||||
# Let's initialize device mesh with mesh_shape = (2, 4) to represent
|
||||
# the topology of cross-host(dim 0), and within-host (dim 1).
|
||||
mesh_2d = DeviceMesh(device_type="cuda",
|
||||
mesh=[
|
||||
[0, 1, 2, 3],
|
||||
[4, 5, 6, 7]
|
||||
])
|
||||
```
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
|
||||
"""
|
||||
if self.ndim > 1 and mesh_dim is None:
|
||||
raise RuntimeError(
|
||||
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
|
||||
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
|
||||
)
|
||||
elif mesh_dim is None:
|
||||
mesh_dim = 0
|
||||
|
||||
mesh_dim_group = self.get_group(mesh_dim) # type: ignore[arg-type]
|
||||
assert isinstance(
|
||||
mesh_dim_group, ProcessGroup
|
||||
), "We expect ProcessGroup before calling `get_rank`!"
|
||||
return get_rank(mesh_dim_group) # type: ignore[arg-type]
|
||||
|
||||
def get_coordinate(self) -> Optional[List[int]]:
|
||||
"""
|
||||
Return the relative indices of this rank relative to all
|
||||
dimensions of the mesh. If this rank is not part of the mesh, return None.
|
||||
"""
|
||||
return self._coordinate_on_dim if self._coordinate_on_dim else None
|
||||
|
||||
def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
|
||||
if self.mesh_dim_names is None or len(self.mesh_dim_names) == 0:
|
||||
raise KeyError(
|
||||
"No `mesh_dim_names` found.",
|
||||
)
|
||||
if mesh_dim_name not in self.mesh_dim_names:
|
||||
raise KeyError(
|
||||
f"Mesh dimension '{mesh_dim_name}' does not exist.",
|
||||
f"Available mesh dimensions are: {self.mesh_dim_names}",
|
||||
)
|
||||
return self.mesh_dim_names.index(mesh_dim_name) # type: ignore[union-attr]
|
||||
|
||||
|
||||
def init_device_mesh(
|
||||
device_type: str,
|
||||
mesh_shape: Tuple[int, ...],
|
||||
*,
|
||||
mesh_dim_names: Optional[Tuple[str, ...]] = None,
|
||||
) -> DeviceMesh:
|
||||
"""
|
||||
Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
|
||||
This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape)
|
||||
and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is
|
||||
labeled as mesh_dim_names[i].
|
||||
|
||||
.. note::
|
||||
`init_device_mesh` follows SPMD programming model, which means the same PyTorch Python program
|
||||
is running on all processes/ranks in the cluster. Therefore, users need to make sure the `mesh_shape`
|
||||
tuple (the dimension of the nD array that describes the layout of devices) should be identical across
|
||||
all ranks. Inconsistent `mesh_shape` will lead to silent hang.
|
||||
|
||||
Args:
|
||||
device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
|
||||
mesh_shape: Tuple[int]: A tuple defines the dimension of the multi-dimesnion array
|
||||
that describes the layout of devices.
|
||||
Kwargs:
|
||||
mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension
|
||||
of the multi-dimensional array that describes the layout of devices. Its length must match the length
|
||||
of `mesh_shape`. Each string in mesh_dim_names must be unique.
|
||||
|
||||
Returns:
|
||||
A :class:`DeviceMesh` object
|
||||
|
||||
.. note: If no process group is found, init_device_mesh will initialize distributed process group/groups
|
||||
behind the scene, which are required for distributed communications.
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> from torch.distributed._tensor.device_mesh import init_device_mesh
|
||||
>>>
|
||||
>>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
|
||||
>>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
|
||||
"""
|
||||
if mesh_dim_names is not None:
|
||||
if len(set(mesh_dim_names)) != len(mesh_dim_names):
|
||||
raise RuntimeError(
|
||||
"Each mesh_dim_name must be unique.",
|
||||
f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
|
||||
)
|
||||
|
||||
if len(mesh_shape) != len(mesh_dim_names):
|
||||
raise RuntimeError(
|
||||
"mesh_shape and mesh_dim_names should have same length!",
|
||||
f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
|
||||
)
|
||||
|
||||
mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
|
||||
device_mesh = DeviceMesh(
|
||||
device_type=device_type,
|
||||
mesh=mesh,
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
)
|
||||
|
||||
return device_mesh
|
||||
|
|
@@ -7,13 +7,9 @@ import torch.distributed._tensor.ops
import torch.distributed._tensor.random as random
from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.api import distribute_module, distribute_tensor, DTensor
from torch.distributed._tensor.device_mesh import (
    _mesh_resources,
    DeviceMesh,
    init_device_mesh,
)
from torch.distributed._tensor.ops.utils import normalize_to_torch_size
from torch.distributed._tensor.placement_types import Placement, Replicate, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh

# All public APIs from dtensor package
__all__ = [
@@ -5,7 +5,7 @@ from typing import List, Optional

import torch
import torch.distributed._tensor.placement_types as placement_types
from torch.distributed._tensor.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
    all_to_all,
    broadcast,
@@ -2,13 +2,13 @@ from typing import cast, List, Sequence, Tuple

import torch
from torch._prims_common import ShapeType
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed.device_mesh import DeviceMesh


# TODO: audit existing code base to see if we can safely remove this API.
@@ -6,8 +6,8 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import torch

import torch.nn as nn
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Placement, Replicate
from torch.distributed.device_mesh import DeviceMesh

log = logging.getLogger(__name__)

@@ -10,7 +10,6 @@ import torch.distributed._tensor.random as random
import torch.nn as nn
from torch.distributed._tensor._collective_utils import mesh_broadcast
from torch.distributed._tensor._utils import compute_global_tensor_info
from torch.distributed._tensor.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Placement,

@@ -26,6 +25,7 @@ from torch.distributed._tensor.redistribute import (
    Redistribute,
    redistribute_local_tensor,
)
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh


__all__ = ["DTensor", "distribute_tensor", "distribute_module"]
@@ -1,4 +1,4 @@
from torch.distributed._device_mesh import ( # noqa: F401
from torch.distributed.device_mesh import ( # noqa: F401
    _get_device_handle,
    _mesh_resources,
    DeviceMesh,
@@ -8,7 +8,6 @@ import torch
import torch.distributed as dist
import torch.distributed._tensor.api as dtensor
import torch.distributed._tensor.random as random
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
    _is_inplace_op,
    _is_out_variant_op,

@@ -24,6 +23,7 @@ from torch.distributed._tensor.tp_conv import (
    convolution_backward_handler,
    convolution_handler,
)
from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.utils import _cxx_pytree as pytree
@@ -3,8 +3,8 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._ops import OpOverload
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed.device_mesh import DeviceMesh

try:
    from torch.utils._cxx_pytree import tree_map_only, TreeSpec
@@ -3,7 +3,6 @@ from dataclasses import dataclass

from typing import List, Tuple

from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
from torch.distributed._tensor.placement_types import (
    _Partial,

@@ -13,6 +12,8 @@ from torch.distributed._tensor.placement_types import (
    Shard,
)

from torch.distributed.device_mesh import DeviceMesh


@dataclass
class EinsumDims:
@@ -4,7 +4,6 @@ from typing import cast, List, Optional, Sequence, Tuple
import torch

import torch.distributed.distributed_c10d as c10d
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
    OpSchema,
    OpStrategy,

@@ -28,6 +27,7 @@ from torch.distributed._tensor.placement_types import (
    Replicate,
    Shard,
)
from torch.distributed.device_mesh import DeviceMesh


aten = torch.ops.aten
@@ -1,8 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# implement matrix related ops for distributed tensor
import torch

from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import OpSchema, OpStrategy, OutputSharding
from torch.distributed._tensor.ops.basic_strategy import gen_einsum_strategies
from torch.distributed._tensor.ops.common_rules import einop_rule

@@ -16,6 +14,8 @@ from torch.distributed._tensor.ops.utils import (
)
from torch.distributed._tensor.placement_types import DTensorSpec

from torch.distributed.device_mesh import DeviceMesh

aten = torch.ops.aten

@@ -2,7 +2,6 @@
from typing import List, Tuple

import torch
from torch.distributed._tensor.device_mesh import DeviceMesh

from torch.distributed._tensor.op_schema import (
    _is_inplace_op,

@@ -30,6 +29,7 @@ from torch.distributed._tensor.placement_types import (
    Replicate,
    Shard,
)
from torch.distributed.device_mesh import DeviceMesh


aten = torch.ops.aten
@@ -1,7 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch

from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
    OpSchema,
    OpStrategy,

@@ -10,6 +8,8 @@ from torch.distributed._tensor.op_schema import (
)
from torch.distributed._tensor.ops.utils import is_tensor_partial, register_op_strategy

from torch.distributed.device_mesh import DeviceMesh

aten = torch.ops.aten

@@ -4,7 +4,6 @@ from typing import cast, List, Optional, Sequence, Tuple
import torch

from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
    OpSchema,
    OpStrategy,

@@ -29,6 +28,7 @@ from torch.distributed._tensor.placement_types import (
    Replicate,
    Shard,
)
from torch.distributed.device_mesh import DeviceMesh


aten = torch.ops.aten
@@ -8,7 +8,7 @@ import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d

from torch.distributed._tensor._collective_utils import mesh_broadcast, mesh_scatter
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed.device_mesh import DeviceMesh


class Placement:
@@ -7,8 +7,8 @@ import torch
import torch.distributed as dist

from torch import Tensor
from torch.distributed._tensor.device_mesh import _get_device_handle, DeviceMesh
from torch.distributed._tensor.placement_types import DTensorSpec, Shard
from torch.distributed.device_mesh import _get_device_handle, DeviceMesh


_rng_tracker: Optional["RNGStateTracker"] = None
@@ -3,7 +3,6 @@ from typing import cast, Dict, List, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,

@@ -11,6 +10,7 @@ from torch.distributed._tensor.placement_types import (
    Replicate,
    Shard,
)
from torch.distributed.device_mesh import DeviceMesh


_PlacementItem = Tuple[int, Tuple[Placement, Placement]]
@@ -5,7 +5,6 @@ from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.op_schema import (
    DTensorSpec,
    OpInfo,

@@ -19,6 +18,7 @@ from torch.distributed._tensor.op_schema import (
    TupleStrategy,
)
from torch.distributed._tensor.placement_types import TensorMeta
from torch.distributed.device_mesh import DeviceMesh

aten = torch.ops.aten

torch/distributed/device_mesh.py (new file, 524 lines)

@@ -0,0 +1,524 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.distributed import is_available
|
||||
|
||||
__all__ = ["init_device_mesh", "DeviceMesh"]
|
||||
|
||||
|
||||
if not is_available():
|
||||
import sys
|
||||
|
||||
# We need to create the stubs when distributed is not available.
|
||||
# Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```),
|
||||
# since it would try to import ``torch.distributed.device_mesh`` or
|
||||
# ``torch.distributed.init_device_mesh`` but cannot find them.
|
||||
|
||||
class _DeviceMeshStub:
|
||||
pass
|
||||
|
||||
def _init_device_mesh_stub():
|
||||
pass
|
||||
|
||||
sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined]
|
||||
sys.modules[
|
||||
"torch.distributed.device_mesh"
|
||||
].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined]
|
||||
|
||||
|
||||
else:
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_find_pg_by_ranks_and_tag,
|
||||
_get_default_group,
|
||||
_get_group_tag,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
init_process_group,
|
||||
is_initialized,
|
||||
new_group,
|
||||
ProcessGroup,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# only import numpy typing when type checking
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
from numpy.typing import ArrayLike
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"DeviceMesh requires numpy >= 1.21 to be installed for type checking"
|
||||
)
|
||||
|
||||
class _MeshEnv:
|
||||
def __init__(self) -> None:
|
||||
self.mesh_stack: List[DeviceMesh] = []
|
||||
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
|
||||
|
||||
def get_current_mesh(self) -> "DeviceMesh":
|
||||
if len(self.mesh_stack) == 0:
|
||||
raise RuntimeError("No device mesh is currently active!")
|
||||
return self.mesh_stack[-1]
|
||||
|
||||
def create_child_mesh(
|
||||
self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
|
||||
) -> "DeviceMesh":
|
||||
# swap the current dim to the last dim then reshape to flatten out other
|
||||
# dims, so we can just extract the list of ranks which contains cur_rank.
|
||||
cur_rank = device_mesh.get_rank()
|
||||
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
|
||||
-1, device_mesh.mesh.size(mesh_dim)
|
||||
)
|
||||
|
||||
for mesh_1d in pg_ranks_by_dim:
|
||||
sub_mesh = DeviceMesh(
|
||||
device_mesh.device_type,
|
||||
mesh_1d,
|
||||
mesh_dim_names=(mesh_dim_name,),
|
||||
_init_process_groups=False,
|
||||
)
|
||||
if cur_rank in mesh_1d:
|
||||
res_sub_mesh = sub_mesh
|
||||
|
||||
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
|
||||
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
|
||||
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
|
||||
return res_sub_mesh
|
||||
|
||||
def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
|
||||
return self.child_to_parent_mapping.get(device_mesh, None)
|
||||
|
||||
def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
|
||||
"""
|
||||
Return the index of the mesh dim in the parent mesh.
|
||||
The device_mesh passed in needs to be sliced out from a parent mesh.
|
||||
"""
|
||||
parent_mesh = self.get_parent_mesh(device_mesh)
|
||||
child_mesh_dim_names = device_mesh.mesh_dim_names
|
||||
if parent_mesh and child_mesh_dim_names:
|
||||
assert (
|
||||
len(child_mesh_dim_names) == 1
|
||||
), "The child mesh can only be a 1D mesh."
|
||||
child_mesh_dim_name = child_mesh_dim_names[0]
|
||||
if parent_mesh.mesh_dim_names:
|
||||
return parent_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def num_devices_per_host(device_type: str) -> int:
|
||||
return _get_device_handle(device_type).device_count()
|
||||
|
||||
@staticmethod
|
||||
def num_hosts(device_type: str) -> int:
|
||||
# ProcessGroup can't tell us this info so we have to infer it, assume
|
||||
# homogeneous hardware for now
|
||||
return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
|
||||
|
||||
_mesh_resources: _MeshEnv = _MeshEnv()
|
||||
|
||||
def _get_device_handle(device_type: str = "cuda"):
|
||||
"""
|
||||
Get the module corresponding to the device_type which is cuda or cuda-like device.
|
||||
For example, when the device_type is cuda, the module `torch.cuda` is returned.
|
||||
Return None when there is no corresponding module for device_type, otherwise
|
||||
return the corresponding module.
|
||||
"""
|
||||
return getattr(torch, device_type, None)
|
||||
|
||||
class DeviceMesh:
|
||||
"""
|
||||
DeviceMesh represents a mesh of devices, where layout of devices could be
|
||||
represented as a n-d dimension array, and each value of the n-d dimensional
|
||||
array is the global id of the default process group ranks.
|
||||
|
||||
DeviceMesh could be used to describe the layout of devices across the cluster,
|
||||
and serves as a proxy for communication among the device lists within the cluster.
|
||||
|
||||
We use the default ProcessGroup in this DeviceMesh class to implement proper
|
||||
communications. Note that we also add collective wrappers in this class. This is
|
||||
used to decouple detailed communication backend with the underlying
|
||||
DTensor implementation.
|
||||
|
||||
DeviceMesh can be used as a context manager.
|
||||
|
||||
.. note::
|
||||
DeviceMesh follows SPMD programming model, which means the same PyTorch Python program
|
||||
is running on all processes/ranks in the cluster. Therefore, users need to make sure the
|
||||
`mesh` array (which describes the layout of devices) should be identical across all ranks.
|
||||
Inconsistent `mesh` will lead to silent hang.
|
||||
|
||||
Args:
|
||||
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
|
||||
mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout
|
||||
of devices, where the IDs are global IDs of the default process group.
|
||||
|
||||
Returns:
|
||||
DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
|
||||
|
||||
Example (2 host with 4 GPUs each):
|
||||
The following program runs on each process/rank in an SPMD manner:
|
||||
>>> # xdoctest: +SKIP("no rank")
|
||||
>>>
|
||||
>>> from torch.distributed import DeviceMesh
|
||||
>>> # Initialize device mesh as (2, 4) to represent the topology
|
||||
>>> # of cross-host(dim 0), and within-host (dim 1).
|
||||
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
|
||||
|
||||
A reduction over the first dimension of mesh will reduce across
|
||||
columns (0, 4), .. and (3, 7), a reduction over the second dimension
|
||||
of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
|
||||
"""
|
||||
|
||||
device_type: str
|
||||
mesh: torch.Tensor
|
||||
mesh_dim_names: Optional[Tuple[str, ...]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_type: str,
|
||||
mesh: Union[torch.Tensor, "ArrayLike"],
|
||||
*,
|
||||
mesh_dim_names: Optional[Tuple[str, ...]] = None,
|
||||
_init_process_groups: bool = True,
|
||||
) -> None:
|
||||
self.device_type = device_type
|
||||
self.mesh = (
|
||||
mesh.detach()
|
||||
if isinstance(mesh, torch.Tensor)
|
||||
else torch.tensor(mesh, dtype=torch.int)
|
||||
)
|
||||
self.mesh_dim_names = mesh_dim_names
|
||||
|
||||
# private field to pre-generate DeviceMesh's hash
|
||||
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
|
||||
self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
|
||||
|
||||
# Skip process group initialization if xla device.
|
||||
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
|
||||
if device_type != "xla":
|
||||
# always try to create default (world) pg, even if it is not initialized
|
||||
# already. The world pg is used for device mesh identity (rank) on each
|
||||
# process (we need to know if the current global rank is in the mesh or not).
|
||||
self._get_or_create_default_group()
|
||||
if _init_process_groups:
|
||||
self._init_process_groups()
|
||||
|
||||
def _get_or_create_default_group(self):
|
||||
default_initialized = is_initialized()
|
||||
if not default_initialized:
|
||||
init_process_group()
|
||||
|
||||
world_size = get_world_size()
|
||||
if self.mesh.numel() > world_size:
|
||||
raise RuntimeError(
|
||||
f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
|
||||
)
|
||||
|
||||
device_handle = _get_device_handle(self.device_type)
|
||||
# TODO: if user want to pass pg_options, offer a way to do it
|
||||
if not default_initialized and device_handle:
|
||||
# automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
|
||||
# NOTE: This device selection would only work for homogeneous hardware.
|
||||
num_devices_per_host = device_handle.device_count()
|
||||
if (
|
||||
world_size > num_devices_per_host
|
||||
and world_size % num_devices_per_host != 0
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"DeviceMesh only support homogeneous hardware, but found "
|
||||
f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
|
||||
)
|
||||
device_handle.set_device(get_rank() % num_devices_per_host)
|
||||
|
||||
# calculate the coordinates of the current global rank on the mesh
|
||||
rank_coords = (self.mesh == get_rank()).nonzero()
|
||||
assert rank_coords.size(0) in (0, 1)
|
||||
self._coordinate_on_dim: Optional[List[int]] = (
|
||||
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
|
||||
)
|
||||
return _get_default_group()
|
||||
|
||||
def _init_process_groups(self):
|
||||
# group tag/ranks associated with each mesh dimension, each mesh dimension should
|
||||
# have one sub-group per rank
|
||||
dim_group_infos: List[Tuple[str, List[int]]] = []
|
||||
|
||||
if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
|
||||
# if the mesh is the same as world_pg, we just append the default
|
||||
# pg to the first dim groups, as new_group cannot have the exact
|
||||
# same ranks as world
|
||||
dim_group_infos.append(
|
||||
(
|
||||
_get_group_tag(_get_default_group()),
|
||||
list(range(get_world_size())),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# create sub pgs base on the mesh argument specified
|
||||
for dim in range(self.mesh.ndim):
|
||||
# swap the current dim to the last dim
|
||||
# then reshape to flatten out other dims
|
||||
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
|
||||
-1, self.mesh.size(dim)
|
||||
)
|
||||
# multi-dim mesh, create subgroups by looping over the pg_ranks
|
||||
# for each dim and append the groups
|
||||
for dim_mesh in pg_ranks_by_dim:
|
||||
subgroup_ranks = dim_mesh.tolist()
|
||||
# call new_group regardless of the current rank in the
|
||||
# pg or not, it's required that all ranks participate
|
||||
# in subgroup construction
|
||||
dim_group = new_group(ranks=subgroup_ranks)
|
||||
# only add to dim_groups if the current rank in the subgroup
|
||||
if self.get_rank() in subgroup_ranks:
|
||||
if len(dim_group_infos) > dim:
|
||||
raise RuntimeError(
|
||||
f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
|
||||
f"in {subgroup_ranks}!"
|
||||
)
|
||||
dim_group_infos.append(
|
||||
(_get_group_tag(dim_group), subgroup_ranks)
|
||||
)
|
||||
self._dim_group_infos = dim_group_infos
|
||||
|
||||
def __enter__(self) -> "DeviceMesh":
|
||||
# set this mesh as the current mesh in mesh env
|
||||
_mesh_resources.mesh_stack.append(self)
|
||||
return self
|
||||
|
||||
# pyre-fixme[2]: Parameter must be annotated.
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
|
||||
# pop this mesh from mesh env
|
||||
_mesh_resources.mesh_stack.pop()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DeviceMesh({self.mesh.tolist()})"
|
||||
|
||||
def __hash__(self):
|
||||
return self._hash
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, DeviceMesh):
|
||||
return False
|
||||
if id(self.mesh) == id(other.mesh):
|
||||
return True
|
||||
return (
|
||||
self.mesh.shape == other.mesh.shape
|
||||
and self._flatten_mesh_list == other._flatten_mesh_list
|
||||
)
|
||||
|
||||
def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
|
||||
"""
|
||||
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
|
||||
DeviceMesh.
|
||||
|
||||
Args:
|
||||
mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
|
||||
to create a child DeviceMesh for.
|
||||
Returns:
|
||||
A :class:`DeviceMesh` object
|
||||
|
||||
Example (2 host with 4 GPUs each):
|
||||
The following program runs on each process/rank in an SPMD manner:
|
||||
>>> # xdoctest: +SKIP("no rank")
|
||||
>>>
|
||||
>>> from torch.distributed import DeviceMesh
|
||||
>>> # Initialize device mesh as (2, 4) to represent the topology
|
||||
>>> # of cross-host(dim 0), and within-host (dim 1).
|
||||
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
|
||||
|
||||
Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
|
||||
Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
|
||||
Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
|
||||
Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
|
||||
Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
|
||||
Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
|
||||
|
||||
"""
|
||||
if self.mesh.ndim <= 1:
|
||||
raise RuntimeError(
|
||||
f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
|
||||
)
|
||||
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
|
||||
submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
|
||||
|
||||
return submesh
|
||||
|
||||
def get_group(
|
||||
self, mesh_dim: Optional[Union[int, str]] = None
|
||||
) -> Union[ProcessGroup, List[ProcessGroup]]:
|
||||
"""
|
||||
Returns a list of ProcessGroups corresponding to the mesh dimensions, or
|
||||
returns a single ProcessGroup if mesh_dim is specified or the given mesh has
|
||||
only one mesh dimension.
|
||||
|
||||
Args:
|
||||
mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
|
||||
of the mesh dimension. Default is None.
|
||||
|
||||
Returns:
|
||||
A list of :class:`ProcessGroup` object when `mesh_dim` is not specified for
|
||||
a DeviceMesh with more than 1 dimension; otherwise, returns a single
|
||||
:class:`ProcessGroup` object.
|
||||
"""
|
||||
if not hasattr(self, "_dim_group_infos"):
|
||||
raise RuntimeError("DeviceMesh process groups not initialized!")
|
||||
|
||||
if self.mesh.ndim == 1:
|
||||
return _find_pg_by_ranks_and_tag(*self._dim_group_infos[0])
|
||||
|
||||
if mesh_dim is not None:
|
||||
if isinstance(mesh_dim, str):
|
||||
mesh_dim = self._get_mesh_dim_by_name(mesh_dim)
|
||||
return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
|
||||
else:
|
||||
dim_groups = []
|
||||
for ith_dim in range(self.mesh.ndim):
|
||||
dim_groups.append(
|
||||
_find_pg_by_ranks_and_tag(*self._dim_group_infos[ith_dim])
|
||||
)
|
||||
return dim_groups
|
||||
|
||||
def size(self, mesh_dim: Optional[int] = None) -> int:
|
||||
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
return self.mesh.ndim
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
return tuple(self.mesh.shape)
|
||||
|
||||
def get_rank(self) -> int:
|
||||
"""
|
||||
Returns the current global rank.
|
||||
"""
|
||||
return get_rank()
|
||||
|
||||
def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
|
||||
"""
|
||||
Returns the local rank of the given mesh_dim of the DeviceMesh.
|
||||
|
||||
Args:
|
||||
mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index
|
||||
of the mesh dimension. Default is None.
|
||||
|
||||
Returns:
|
||||
An integer denotes the local rank.
|
||||
|
||||
Example (2 host with 4 GPUs each):
|
||||
The following program runs on each process/rank in an SPMD manner:
|
||||
>>> # xdoctest: +SKIP("no rank")
|
||||
>>>
|
||||
>>> from torch.distributed import DeviceMesh
|
||||
>>> # Initialize device mesh as (2, 4) to represent the topology
|
||||
>>> # of cross-host(dim 0), and within-host (dim 1).
|
||||
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
|
||||
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2.
|
||||
Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3.
|
||||
|
||||
"""
|
||||
if self.ndim > 1 and mesh_dim is None:
|
||||
raise RuntimeError(
|
||||
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
|
||||
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
|
||||
)
|
||||
elif mesh_dim is None:
|
||||
mesh_dim = 0
|
||||
|
||||
mesh_dim_group = self.get_group(mesh_dim) # type: ignore[arg-type]
|
||||
assert isinstance(
|
||||
mesh_dim_group, ProcessGroup
|
||||
), "We expect ProcessGroup before calling `get_rank`!"
|
||||
return get_rank(mesh_dim_group) # type: ignore[arg-type]
|
||||
|
||||
def get_coordinate(self) -> Optional[List[int]]:
|
||||
"""
|
||||
Return the relative indices of this rank relative to all
|
||||
dimensions of the mesh. If this rank is not part of the mesh, return None.
|
||||
"""
|
||||
return self._coordinate_on_dim if self._coordinate_on_dim else None
|
||||
|
||||
def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int:
|
||||
if self.mesh_dim_names is None or len(self.mesh_dim_names) == 0:
|
||||
raise KeyError(
|
||||
"No `mesh_dim_names` found.",
|
||||
)
|
||||
if mesh_dim_name not in self.mesh_dim_names:
|
||||
raise KeyError(
|
||||
f"Mesh dimension '{mesh_dim_name}' does not exist.",
|
||||
f"Available mesh dimensions are: {self.mesh_dim_names}",
|
||||
)
|
||||
return self.mesh_dim_names.index(mesh_dim_name) # type: ignore[union-attr]
|
||||
|
||||
def init_device_mesh(
|
||||
device_type: str,
|
||||
mesh_shape: Tuple[int, ...],
|
||||
*,
|
||||
mesh_dim_names: Optional[Tuple[str, ...]] = None,
|
||||
) -> DeviceMesh:
|
||||
"""
|
||||
Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
|
||||
|
||||
This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`.
|
||||
If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`.
|
||||
|
||||
.. note::
|
||||
`init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program
|
||||
runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array
|
||||
describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging.
|
||||
|
||||
.. note::
|
||||
If no process group is found, init_device_mesh will initialize distributed process group/groups
|
||||
required for distributed communications behind the scene.
|
||||
|
||||
Args:
|
||||
device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
|
||||
mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array
|
||||
describing the layout of devices.
|
||||
mesh_dim_names (Tuple[str], optional): A tuple of mesh dimension names to assign to each dimension
|
||||
of the multi-dimensional array describing the layout of devices. Its length must match the length
|
||||
of `mesh_shape`. Each string in `mesh_dim_names` must be unique.
|
||||
|
||||
Returns:
|
||||
DeviceMesh: A :class:`DeviceMesh` object representing the device layout.
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP("no rank")
|
||||
>>> from torch.distributed import init_device_mesh
|
||||
>>>
|
||||
>>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
|
||||
>>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
|
||||
|
||||
"""
|
||||
if mesh_dim_names is not None:
|
||||
if len(set(mesh_dim_names)) != len(mesh_dim_names):
|
||||
raise RuntimeError(
|
||||
"Each mesh_dim_name must be unique.",
|
||||
f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
|
||||
)
|
||||
|
||||
if len(mesh_shape) != len(mesh_dim_names):
|
||||
raise RuntimeError(
|
||||
"mesh_shape and mesh_dim_names should have same length!",
|
||||
f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
|
||||
)
|
||||
|
||||
mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
|
||||
device_mesh = DeviceMesh(
|
||||
device_type=device_type,
|
||||
mesh=mesh,
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
)
|
||||
|
||||
return device_mesh
|
||||
|
|
@@ -24,8 +24,8 @@ import torch.distributed.fsdp._exec_order_utils as exec_order_utils
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed._tensor.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._common_utils import (
    _FSDPDeviceHandle,
@@ -17,7 +17,7 @@ from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.device_mesh import _mesh_resources
from torch.distributed.device_mesh import _mesh_resources

from torch.distributed.fsdp._common_utils import (
    _FSDPState,
@@ -1,11 +1,11 @@
import functools
import warnings
from typing import Callable, Tuple, Optional, Union
from typing import Callable, Optional, Tuple, Union

import torch
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.device_mesh import _mesh_resources
from torch.distributed._tensor.placement_types import Placement
from torch.distributed.device_mesh import _mesh_resources
try:
    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
except Exception:
@@ -16,7 +16,7 @@ from torch.distributed._shard.sharded_tensor import (
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed._tensor.device_mesh import _mesh_resources
from torch.distributed.device_mesh import _mesh_resources

from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
@@ -10049,7 +10049,7 @@ class DistributedTest:
        """
        world_size = int(os.environ["WORLD_SIZE"])

        from torch.distributed._device_mesh import init_device_mesh
        from torch.distributed.device_mesh import init_device_mesh
        device_mesh = init_device_mesh("cuda", (world_size,))

        pg = _get_default_group()