[2/N][DeviceMesh] Overriding __getitem__ for DeviceMesh to support Mesh Slicing (#107730)

Add support for DeviceMesh slicing by overloading __getitem__ for DeviceMesh.

With this change, you can do:
```
mesh_shape = (2, 4)
mesh_dim_names = ("DP", "TP")
two_d_mesh = init_device_mesh(
    self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
tp_mesh = two_d_mesh["TP"]
```

cc. @wanchaol, @fduwjj
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107730
Approved by: https://github.com/wanchaol
This commit is contained in:
wz337 2023-08-23 20:35:26 +00:00 committed by PyTorch MergeBot
parent 652ccfadc1
commit cdd0821f00
2 changed files with 147 additions and 2 deletions

View File

@ -9,7 +9,11 @@ from torch.distributed._tensor._collective_utils import (
mesh_broadcast,
mesh_scatter,
)
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed._tensor.device_mesh import (
DeviceMesh,
init_device_mesh,
mesh_resources,
)
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.distributed_c10d import (
@ -180,6 +184,72 @@ class InitDeviceMeshTest(DTensorTestBase):
self.assertEqual(two_d_mesh, ref_mesh)
class TestDeviceMeshGetItem(DTensorTestBase):
@property
def world_size(self):
return 8
@with_comms
def test_raises_mesh_dim_less_than_2(self):
with self.assertRaisesRegex(RuntimeError, "Cannot slice a DeviceMesh"):
mesh = init_device_mesh(self.device_type, (8,))
child_mesh = mesh["DP"]
@with_comms
def test_raises_no_mesh_dim_found(self):
with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."):
mesh = init_device_mesh(self.device_type, (2, 4))
child_mesh = mesh["DP"]
@with_comms
def test_raises_invalid_mesh_dim_name(self):
child_mesh_dim_name = "PP"
with self.assertRaisesRegex(
KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist."
):
mesh_dim_names = ("DP", "TP")
mesh = init_device_mesh(
self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
)
child_mesh = mesh[child_mesh_dim_name]
@with_comms
def test_get_item(self):
mesh_shape = (2, 4)
mesh_dim_names = ("DP", "TP")
two_d_mesh = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
pg_ranks_by_dim_name = {}
for mesh_dim_name in mesh_dim_names:
mesh_dim = mesh_dim_names.index(mesh_dim_name)
pg_ranks_by_dim_name[mesh_dim_name] = two_d_mesh.mesh.swapdims(
-1, mesh_dim
).reshape(-1, two_d_mesh.mesh.size(mesh_dim))
tp_mesh = two_d_mesh["TP"]
tp_group_idx = self.rank // 4
self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])
dp_mesh = two_d_mesh["DP"]
dp_group_idx = self.rank % 4
self.assertEqual(
two_d_mesh["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx]
)
@with_comms
def test_get_parent_mesh(self):
mesh_shape = (2, 4)
mesh_dim_names = ("DP", "TP")
two_d_mesh = init_device_mesh(
self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
self.assertEqual(mesh_resources.get_parent_mesh(two_d_mesh["DP"]), two_d_mesh)
self.assertEqual(mesh_resources.get_parent_mesh(two_d_mesh["TP"]), two_d_mesh)
class DeviceMeshCollectiveTest(DTensorTestBase):
@property
def world_size(self):

View File

@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
import torch.distributed._functional_collectives as funcol
@ -34,12 +34,38 @@ if TYPE_CHECKING:
class _MeshEnv:
def __init__(self) -> None:
self.mesh_stack: List[DeviceMesh] = []
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
def get_current_mesh(self) -> "DeviceMesh":
if len(self.mesh_stack) == 0:
raise RuntimeError("No device mesh is currently active!")
return self.mesh_stack[-1]
def create_child_mesh(
self, device_mesh: "DeviceMesh", mesh_dim: int
) -> "DeviceMesh":
# swap the current dim to the last dim then reshape to flatten out other
# dims, so we can just extract the list of ranks which contains cur_rank.
cur_rank = device_mesh.get_rank()
pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
-1, device_mesh.mesh.size(mesh_dim)
)
for mesh_1d in pg_ranks_by_dim:
sub_mesh = DeviceMesh(
device_mesh.device_type, mesh_1d, _init_process_groups=False
)
if cur_rank in mesh_1d:
res_sub_mesh = sub_mesh
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
return res_sub_mesh
def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
return self.child_to_parent_mapping.get(device_mesh, None)
mesh_resources: _MeshEnv = _MeshEnv()
@ -97,6 +123,7 @@ class DeviceMesh:
device_type: str
mesh: torch.Tensor
mesh_dim_names: Optional[Tuple[str, ...]]
def __init__(
self,
@ -241,6 +268,54 @@ class DeviceMesh:
return True
return self.mesh.equal(other.mesh)
def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
"""
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
DeviceMesh.
Args:
mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
to create a child DeviceMesh for.
Returns:
A :class:`DeviceMesh` object
Example (2 host with 4 GPUs each):
```
# Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp")
mesh = DeviceMesh(device_type="cuda",
mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
],
mesh_dim_names=["dp", "tp"])
)
```
Calling mesh["dp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
Calling mesh["dp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
Calling mesh["tp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
Calling mesh["tp"] on rank 1, 3 would return a 1D child DeviceMesh:([1, 3]).
Calling mesh["tp"] on rank 2, 5 would return a 1D child DeviceMesh:([2, 5]).
Calling mesh["tp"] on rank 4, 7 would return a 1D child DeviceMesh:([4, 7]).
"""
if self.mesh.ndim <= 1:
raise RuntimeError(
f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
)
if self.mesh_dim_names is None:
raise KeyError(
"No `mesh_dim_names` found.",
"To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
)
if mesh_dim_name not in self.mesh_dim_names:
raise KeyError(
f"Mesh dimension '{mesh_dim_name}' does not exist.",
f"Available mesh dimensions are: {self.mesh_dim_names}",
)
mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
submesh = mesh_resources.create_child_mesh(self, mesh_dim)
return submesh
def get_dim_groups(
self, mesh_dim: Optional[int] = None
) -> Union[ProcessGroup, List[ProcessGroup]]: