[DeviceMesh] Cache and reuse sliced result (#122975)

Fixes #118849

Add a parent_to_child_mappings map to _mesh_resources so we can cache and reuse submesh slicing results. This avoids recreating the submesh and the underlying sub process group repeatedly, which could lead to inconsistent behavior.

We will follow up by reusing the process group from the parent_mesh during submesh creation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122975
Approved by: https://github.com/wanchaol
This commit is contained in:
wz337 2024-03-30 23:56:52 +00:00 committed by PyTorch MergeBot
parent 35c493f2cf
commit 2b1ba0ceae
2 changed files with 35 additions and 4 deletions

View File

@ -14,6 +14,7 @@ from torch.distributed._tensor.placement_types import _Partial, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
from torch.distributed.distributed_c10d import ( from torch.distributed.distributed_c10d import (
_world,
get_global_rank, get_global_rank,
get_world_size, get_world_size,
init_process_group, init_process_group,
@ -320,6 +321,24 @@ class TestDeviceMeshGetItem(DTensorTestBase):
with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"): with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"):
dp_mesh = mesh["dim0"] dp_mesh = mesh["dim0"]
@with_comms
@run_with_both_funcol_impls
def test_cache_and_reuse_submesh_slice_result(self):
    """Verify that slicing the same mesh dim twice reuses the cached submesh.

    Slicing a DeviceMesh by dim name creates a sub process group on first
    use; repeating the same slice must hit the cache in _mesh_resources and
    create no additional process groups. The test observes this through the
    global process-group counter (_world.group_count).
    """
    mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))
    # First "dp" slice: creates the submesh (and its pg) and caches it.
    dp_mesh = mesh["dp"]
    # Snapshot the global pg count after the first slice.
    ref_pg_count = _world.group_count
    # When we call the "dp" slice a second time, it should not create any new pg.
    # As we are just using the cached result, the pg count should be the same.
    dp_mesh_2 = mesh["dp"]
    self.assertEqual(ref_pg_count, _world.group_count)
    # When we call the "tp" slice, it should create a new pg, as the "tp" slice
    # is called for the first time.
    tp_mesh = mesh["tp"]
    self.assertTrue(_world.group_count > ref_pg_count)
@instantiate_parametrized_tests @instantiate_parametrized_tests
class TestMeshEnv(DTensorTestBase): class TestMeshEnv(DTensorTestBase):
@ -482,9 +501,11 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim) torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
) )
unpadded_list = [ unpadded_list = [
shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i]) (
if pad_sizes[i] > 0 shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i])
else big_tensor_chunks[i] if pad_sizes[i] > 0
else big_tensor_chunks[i]
)
for i, big_tensor in enumerate(big_tensor_chunks) for i, big_tensor in enumerate(big_tensor_chunks)
] ]
all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim) all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)

View File

@ -60,6 +60,7 @@ else:
def __init__(self) -> None: def __init__(self) -> None:
self.mesh_stack: List[DeviceMesh] = [] self.mesh_stack: List[DeviceMesh] = []
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {} self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
self.parent_to_child_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
def get_current_mesh(self) -> "DeviceMesh": def get_current_mesh(self) -> "DeviceMesh":
if len(self.mesh_stack) == 0: if len(self.mesh_stack) == 0:
@ -69,6 +70,13 @@ else:
def create_child_mesh( def create_child_mesh(
self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
) -> "DeviceMesh": ) -> "DeviceMesh":
# Directly return the child mesh if it is already created.
child_mesh_mappings = self.parent_to_child_mapping.get(device_mesh)
if child_mesh_mappings:
sub_mesh = child_mesh_mappings.get(mesh_dim_name)
if sub_mesh:
return sub_mesh
# swap the current dim to the last dim then reshape to flatten out other # swap the current dim to the last dim then reshape to flatten out other
# dims, so we can just extract the list of ranks which contains cur_rank. # dims, so we can just extract the list of ranks which contains cur_rank.
cur_rank = device_mesh.get_rank() cur_rank = device_mesh.get_rank()
@ -88,6 +96,9 @@ else:
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined] res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]
# Assign the current DeviceMesh as the parent of the child DeviceMesh. # Assign the current DeviceMesh as the parent of the child DeviceMesh.
self.child_to_parent_mapping[res_sub_mesh] = device_mesh self.child_to_parent_mapping[res_sub_mesh] = device_mesh
self.parent_to_child_mapping.setdefault(device_mesh, {})[
mesh_dim_name
] = res_sub_mesh
return res_sub_mesh return res_sub_mesh
def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
@ -378,7 +389,6 @@ else:
mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name) mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name)
submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
return submesh return submesh
def get_group( def get_group(