[2/N][DeviceMesh] Overriding __getitem__ for DeviceMesh to support Mesh Slicing (#107730)
Add support for DeviceMesh slicing by overloading __getitem__ for DeviceMesh.
With this change, you can do:
```
mesh_shape = (2, 4)
mesh_dim_names = ("DP", "TP")
two_d_mesh = init_device_mesh(
    self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
tp_mesh = two_d_mesh["TP"]
```
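For context, a rough sketch of how the sliced sub-mesh could then be used on its own (not part of this PR: it assumes an 8-rank `torchrun` launch, a `"cuda"` device type, and the pre-existing `distribute_tensor`/`Shard` DTensor APIs):

```
# Hypothetical usage sketch; assumes 8 ranks launched with torchrun and the
# torch.distributed._tensor import paths as of this PR.
import torch
from torch.distributed._tensor import distribute_tensor, Shard
from torch.distributed._tensor.device_mesh import init_device_mesh

# Build a 2x4 mesh: "DP" groups of size 2 and "TP" groups of size 4.
two_d_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("DP", "TP"))

# Slice out the 1D tensor-parallel sub-mesh this rank belongs to.
tp_mesh = two_d_mesh["TP"]
print(f"rank {tp_mesh.get_rank()} -> TP ranks {tp_mesh.mesh.tolist()}")

# The sub-mesh behaves like any 1D DeviceMesh, e.g. for sharding a tensor
# across the 4 TP ranks (distribute_tensor is existing DTensor API).
dt = distribute_tensor(torch.randn(8, 8), tp_mesh, [Shard(0)])
```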
cc. @wanchaol, @fduwjj
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107730
Approved by: https://github.com/wanchaol
parent 652ccfadc1
commit cdd0821f00
```diff
@@ -9,7 +9,11 @@ from torch.distributed._tensor._collective_utils import (
     mesh_broadcast,
     mesh_scatter,
 )
-from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed._tensor.device_mesh import (
+    DeviceMesh,
+    init_device_mesh,
+    mesh_resources,
+)
 from torch.distributed._tensor.placement_types import Shard

 from torch.distributed.distributed_c10d import (
```
```diff
@@ -180,6 +184,72 @@ class InitDeviceMeshTest(DTensorTestBase):
         self.assertEqual(two_d_mesh, ref_mesh)


+class TestDeviceMeshGetItem(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 8
+
+    @with_comms
+    def test_raises_mesh_dim_less_than_2(self):
+        with self.assertRaisesRegex(RuntimeError, "Cannot slice a DeviceMesh"):
+            mesh = init_device_mesh(self.device_type, (8,))
+            child_mesh = mesh["DP"]
+
+    @with_comms
+    def test_raises_no_mesh_dim_found(self):
+        with self.assertRaisesRegex(KeyError, "No `mesh_dim_names` found."):
+            mesh = init_device_mesh(self.device_type, (2, 4))
+            child_mesh = mesh["DP"]
+
+    @with_comms
+    def test_raises_invalid_mesh_dim_name(self):
+        child_mesh_dim_name = "PP"
+        with self.assertRaisesRegex(
+            KeyError, f"Mesh dimension '{child_mesh_dim_name}' does not exist."
+        ):
+            mesh_dim_names = ("DP", "TP")
+            mesh = init_device_mesh(
+                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
+            )
+            child_mesh = mesh[child_mesh_dim_name]
+
+    @with_comms
+    def test_get_item(self):
+        mesh_shape = (2, 4)
+        mesh_dim_names = ("DP", "TP")
+        two_d_mesh = init_device_mesh(
+            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
+        )
+
+        pg_ranks_by_dim_name = {}
+        for mesh_dim_name in mesh_dim_names:
+            mesh_dim = mesh_dim_names.index(mesh_dim_name)
+            pg_ranks_by_dim_name[mesh_dim_name] = two_d_mesh.mesh.swapdims(
+                -1, mesh_dim
+            ).reshape(-1, two_d_mesh.mesh.size(mesh_dim))
+
+        tp_mesh = two_d_mesh["TP"]
+        tp_group_idx = self.rank // 4
+        self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx])
+
+        dp_mesh = two_d_mesh["DP"]
+        dp_group_idx = self.rank % 4
+        self.assertEqual(
+            two_d_mesh["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx]
+        )
+
+    @with_comms
+    def test_get_parent_mesh(self):
+        mesh_shape = (2, 4)
+        mesh_dim_names = ("DP", "TP")
+        two_d_mesh = init_device_mesh(
+            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
+        )
+
+        self.assertEqual(mesh_resources.get_parent_mesh(two_d_mesh["DP"]), two_d_mesh)
+        self.assertEqual(mesh_resources.get_parent_mesh(two_d_mesh["TP"]), two_d_mesh)
+
+
 class DeviceMeshCollectiveTest(DTensorTestBase):
     @property
     def world_size(self):
```
```diff
@@ -1,7 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import logging
 import math
-from typing import List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union

 import torch
 import torch.distributed._functional_collectives as funcol
```
```diff
@@ -34,12 +34,38 @@ if TYPE_CHECKING:
 class _MeshEnv:
     def __init__(self) -> None:
         self.mesh_stack: List[DeviceMesh] = []
+        self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}

     def get_current_mesh(self) -> "DeviceMesh":
         if len(self.mesh_stack) == 0:
             raise RuntimeError("No device mesh is currently active!")
         return self.mesh_stack[-1]

+    def create_child_mesh(
+        self, device_mesh: "DeviceMesh", mesh_dim: int
+    ) -> "DeviceMesh":
+        # swap the current dim to the last dim then reshape to flatten out other
+        # dims, so we can just extract the list of ranks which contains cur_rank.
+        cur_rank = device_mesh.get_rank()
+        pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
+            -1, device_mesh.mesh.size(mesh_dim)
+        )
+
+        for mesh_1d in pg_ranks_by_dim:
+            sub_mesh = DeviceMesh(
+                device_mesh.device_type, mesh_1d, _init_process_groups=False
+            )
+            if cur_rank in mesh_1d:
+                res_sub_mesh = sub_mesh
+
+        res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
+        # Assign the current DeviceMesh as the parent of the child DeviceMesh.
+        self.child_to_parent_mapping[res_sub_mesh] = device_mesh
+        return res_sub_mesh
+
+    def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
+        return self.child_to_parent_mapping.get(device_mesh, None)
+

 mesh_resources: _MeshEnv = _MeshEnv()

```
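To make the `swapdims(-1, mesh_dim).reshape(...)` grouping in `create_child_mesh` concrete, here is a small single-process sketch on a (2, 4) rank layout (plain tensors only; the variable names are illustrative, not from the PR):

```
import torch

# Rank layout of a (2, 4) mesh: dim 0 = "dp", dim 1 = "tp".
mesh = torch.arange(8).reshape(2, 4)  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Group ranks along the "tp" dim (mesh_dim = 1): swap it to the last dim,
# then flatten everything else so each row is one 1D sub-mesh.
tp_groups = mesh.swapdims(-1, 1).reshape(-1, mesh.size(1))
print(tp_groups.tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# Same for the "dp" dim (mesh_dim = 0): each column of the original mesh
# becomes one sub-mesh, i.e. the ranks that share a TP position.
dp_groups = mesh.swapdims(-1, 0).reshape(-1, mesh.size(0))
print(dp_groups.tolist())  # [[0, 4], [1, 5], [2, 6], [3, 7]]

# create_child_mesh then keeps the row that contains the current rank, e.g.
cur_rank = 5
row = next(r for r in tp_groups if cur_rank in r)
print(row.tolist())  # [4, 5, 6, 7]
```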
```diff
@@ -97,6 +123,7 @@ class DeviceMesh:

     device_type: str
     mesh: torch.Tensor
+    mesh_dim_names: Optional[Tuple[str, ...]]

     def __init__(
         self,
```
```diff
@@ -241,6 +268,54 @@ class DeviceMesh:
             return True
         return self.mesh.equal(other.mesh)

+    def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
+        """
+        Slice the current DeviceMesh based on the mesh_dim_name given to create a child
+        DeviceMesh.
+
+        Args:
+            mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
+                to create a child DeviceMesh for.
+        Returns:
+            A :class:`DeviceMesh` object.
+
+        Example (2 hosts with 4 GPUs each):
+        ```
+        # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_names of ("dp", "tp").
+        mesh = DeviceMesh(
+            device_type="cuda",
+            mesh=[
+                [0, 1, 2, 3],
+                [4, 5, 6, 7],
+            ],
+            mesh_dim_names=["dp", "tp"],
+        )
+        ```
+        Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh: ([0, 1, 2, 3]).
+        Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh: ([4, 5, 6, 7]).
+        Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh: ([0, 4]).
+        Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh: ([1, 5]).
+        Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh: ([2, 6]).
+        Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh: ([3, 7]).
+        """
+        if self.mesh.ndim <= 1:
+            raise RuntimeError(
+                f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
+            )
+        if self.mesh_dim_names is None:
+            raise KeyError(
+                "No `mesh_dim_names` found.",
+                "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
+            )
+        if mesh_dim_name not in self.mesh_dim_names:
+            raise KeyError(
+                f"Mesh dimension '{mesh_dim_name}' does not exist.",
+                f"Available mesh dimensions are: {self.mesh_dim_names}",
+            )
+        mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
+        submesh = mesh_resources.create_child_mesh(self, mesh_dim)
+
+        return submesh
+
     def get_dim_groups(
         self, mesh_dim: Optional[int] = None
     ) -> Union[ProcessGroup, List[ProcessGroup]]:
```
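Putting the pieces together, a rough usage sketch of slicing plus the new parent-mesh lookup (it mirrors the added tests rather than new behavior; assumes an 8-rank launch and the `_tensor`-internal import paths used in this PR):

```
# Sketch only; assumes 8 ranks and torch.distributed._tensor import paths as of this PR.
from torch.distributed._tensor.device_mesh import init_device_mesh, mesh_resources

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("DP", "TP"))

dp_mesh = mesh["DP"]  # 1D sub-mesh of size 2 containing this rank
tp_mesh = mesh["TP"]  # 1D sub-mesh of size 4 containing this rank

# Both children remember the mesh they were sliced from.
assert mesh_resources.get_parent_mesh(dp_mesh) == mesh
assert mesh_resources.get_parent_mesh(tp_mesh) == mesh

# Slicing with an unknown name raises KeyError, per __getitem__ above.
try:
    mesh["PP"]
except KeyError as e:
    print(e)  # ("Mesh dimension 'PP' does not exist.", "Available mesh dimensions are: ...")
```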