mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[DeviceMesh] Cache and reuse sliced result (#122975)
Fixes #118849 Add a map for parent_to_child_mappings in _mesh_resources so we can cache and reuse submesh slicing result so that we can avoid recreating submesh and the underlying sub pg repeatedly, which could lead to funky behaviors. We will follow up with reusing pg from the parent_mesh during submesh creation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122975 Approved by: https://github.com/wanchaol
This commit is contained in:
parent
35c493f2cf
commit
2b1ba0ceae
|
|
@ -14,6 +14,7 @@ from torch.distributed._tensor.placement_types import _Partial, Shard
|
|||
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
|
||||
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_world,
|
||||
get_global_rank,
|
||||
get_world_size,
|
||||
init_process_group,
|
||||
|
|
@ -320,6 +321,24 @@ class TestDeviceMeshGetItem(DTensorTestBase):
|
|||
with self.assertRaisesRegex(RuntimeError, "Invalid mesh_dim_name"):
|
||||
dp_mesh = mesh["dim0"]
|
||||
|
||||
@with_comms
|
||||
@run_with_both_funcol_impls
|
||||
def test_cache_and_reuse_submesh_slice_result(self):
|
||||
mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp"))
|
||||
|
||||
dp_mesh = mesh["dp"]
|
||||
ref_pg_count = _world.group_count
|
||||
|
||||
# When we call the "dp" slice second time, it should not create any new pg.
|
||||
# As we are just using the cached result so the pg count should be the same.
|
||||
dp_mesh_2 = mesh["dp"]
|
||||
self.assertEqual(ref_pg_count, _world.group_count)
|
||||
|
||||
# When we call the "tp" slice, it should create a new pg, as the "tp" slice is called
|
||||
# for the first time.
|
||||
tp_mesh = mesh["tp"]
|
||||
self.assertTrue(_world.group_count > ref_pg_count)
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestMeshEnv(DTensorTestBase):
|
||||
|
|
@ -482,9 +501,11 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
|
|||
torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
|
||||
)
|
||||
unpadded_list = [
|
||||
shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i])
|
||||
if pad_sizes[i] > 0
|
||||
else big_tensor_chunks[i]
|
||||
(
|
||||
shard_placement._unpad_tensor(big_tensor_chunks[i], pad_sizes[i])
|
||||
if pad_sizes[i] > 0
|
||||
else big_tensor_chunks[i]
|
||||
)
|
||||
for i, big_tensor in enumerate(big_tensor_chunks)
|
||||
]
|
||||
all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ else:
|
|||
def __init__(self) -> None:
|
||||
self.mesh_stack: List[DeviceMesh] = []
|
||||
self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
|
||||
self.parent_to_child_mapping: Dict[DeviceMesh, Dict[str, DeviceMesh]] = {}
|
||||
|
||||
def get_current_mesh(self) -> "DeviceMesh":
|
||||
if len(self.mesh_stack) == 0:
|
||||
|
|
@ -69,6 +70,13 @@ else:
|
|||
def create_child_mesh(
|
||||
self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
|
||||
) -> "DeviceMesh":
|
||||
# Directly return the child mesh if it is already created.
|
||||
child_mesh_mappings = self.parent_to_child_mapping.get(device_mesh)
|
||||
if child_mesh_mappings:
|
||||
sub_mesh = child_mesh_mappings.get(mesh_dim_name)
|
||||
if sub_mesh:
|
||||
return sub_mesh
|
||||
|
||||
# swap the current dim to the last dim then reshape to flatten out other
|
||||
# dims, so we can just extract the list of ranks which contains cur_rank.
|
||||
cur_rank = device_mesh.get_rank()
|
||||
|
|
@ -88,6 +96,9 @@ else:
|
|||
res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] # type: ignore[possibly-undefined]
|
||||
# Assign the current DeviceMesh as the parent of the child DeviceMesh.
|
||||
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
|
||||
self.parent_to_child_mapping.setdefault(device_mesh, {})[
|
||||
mesh_dim_name
|
||||
] = res_sub_mesh
|
||||
return res_sub_mesh
|
||||
|
||||
def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
|
||||
|
|
@ -378,7 +389,6 @@ else:
|
|||
|
||||
mesh_dim = _mesh_resources.get_mesh_dim_by_name(self, mesh_dim_name)
|
||||
submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
|
||||
|
||||
return submesh
|
||||
|
||||
def get_group(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user