# Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] import os import unittest from datetime import timedelta import torch import torch.distributed as dist import torch.distributed._functional_collectives as funcol from torch._C._distributed_c10d import Backend as C10dBackend from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._mesh_layout import _MeshLayout as _Layout from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh from torch.distributed.distributed_c10d import ( _get_default_group, _world, get_global_rank, get_world_size, init_process_group, is_initialized, new_group, ProcessGroup, ) from torch.distributed.tensor import DTensor from torch.distributed.tensor._collective_utils import ( mesh_broadcast, mesh_scatter, unpad_tensor, ) from torch.distributed.tensor.placement_types import _Partial, Shard from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, ) from torch.testing._internal.distributed.fake_pg import FakeProcessGroup, FakeStore from torch.utils._typing_utils import not_none device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" device_count = torch.accelerator.device_count() try: import torch._C._distributed_c10d.ProcessGroupNCCL _NCCL_AVAILABLE = True except ImportError: _NCCL_AVAILABLE = False def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0, local_rank=-1): os.environ["MASTER_ADDR"] = addr os.environ["MASTER_PORT"] = port os.environ["WORLD_SIZE"] = f"{world_size}" os.environ["RANK"] = f"{rank}" if local_rank != -1: os.environ["LOCAL_RANK"] = f"{local_rank}" @unittest.skipIf(TEST_XPU, "XPU does not support gloo backend.") class DeviceMeshTestGlooBackend(DTensorTestBase): @property def backend(self): return "gloo" @with_comms def test_device_mesh_reuse_default_group(self): mesh = init_device_mesh(self.device_type, (self.world_size,)) mesh_group = mesh.get_group() default_group = _get_default_group() if torch.cuda.is_available(): self.assertNotEqual(mesh_group, default_group) self.assertEqual(get_world_size(mesh_group), get_world_size(default_group)) else: self.assertEqual(mesh_group, default_group) class DeviceMeshSetDeviceTest(DTensorTestBase): @property def world_size(self): return 4 @skip_if_lt_x_gpu(4) def test_manual_set_device(self): mesh_tensor = torch.arange(4).reshape(2, 2) self.assertTrue(not is_initialized()) # Set the device on each process before DeviceMesh constructor, # and device to be different than the default world rank torch.accelerator.set_device_index((self.rank + 2) % self.world_size) _set_env_var(world_size=self.world_size, rank=self.rank) DeviceMesh(self.device_type, mesh_tensor) self.assertTrue(is_initialized()) # check that the device is set to the correct device # and respect the previous set_device calls self.assertEqual( torch.accelerator.current_device_idx(), (self.rank + 2) % self.world_size ) self.destroy_pg() @skip_if_lt_x_gpu(4) def test_auto_set_device_from_local_rank(self): mesh_tensor = torch.arange(4).reshape(2, 2) self.assertTrue(not is_initialized()) # set the local rank to be different than the default world rank, # DeviceMesh should respect LOCAL_RANK env var if it's set local_rank = (self.rank + 1) % self.world_size _set_env_var( world_size=self.world_size, rank=self.rank, local_rank=local_rank, ) DeviceMesh(self.device_type, mesh_tensor) self.assertTrue(is_initialized()) # check that the device is set to the correct device # and respect the LOCAL_RANK env var self.assertEqual(torch.accelerator.current_device_idx(), local_rank) self.destroy_pg() @skip_if_lt_x_gpu(4) def test_auto_set_device_from_heuristic(self): mesh_tensor = torch.arange(4).reshape(2, 2) self.assertTrue(not is_initialized()) _set_env_var( world_size=self.world_size, rank=self.rank, ) with self.assertWarnsRegex( UserWarning, "It seems like you did not set/select the default device" ): DeviceMesh(self.device_type, mesh_tensor) self.assertTrue(is_initialized()) # check that the device is set to the correct device self.assertEqual(torch.accelerator.current_device_idx(), self.rank) self.destroy_pg() class DeviceMeshTest(DTensorTestBase): @property def world_size(self): return 4 @skip_if_lt_x_gpu(4) def test_init_process_group(self): mesh_tensor = torch.arange(4).reshape(2, 2) self.assertTrue(not is_initialized()) _set_env_var(world_size=self.world_size, rank=self.rank) DeviceMesh(self.device_type, mesh_tensor) self.assertTrue(is_initialized()) self.destroy_pg(self.rank) @with_comms @skip_if_lt_x_gpu(4) def test_assert_invalid_mesh_tensor(self): mesh = torch.arange(self.world_size).to(self.rank) with self.assertRaises(ValueError): DeviceMesh(self.device_type, mesh) @with_comms() def test_2d_mesh_non_eager_init_subgroup(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh(self.device_type, mesh_shape) self.assertEqual(mesh_2d.get_group(0).bound_device_id, None) self.assertEqual(mesh_2d.get_group(1).bound_device_id, None) # TODO: need to refactor the other tests in this file to test both # eager_init=True and eager_init=False scenarios. @with_comms(eager_init=True) def test_2d_mesh_eager_init_subgroup(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh(self.device_type, mesh_shape) # when eager init is used, the subgroup is created from nccl comm split and # there would be bound_device_id immediately assigned for the subgroup. if self.backend == "nccl": curr_device = torch.cuda.current_device() self.assertEqual(mesh_2d.get_group(0).bound_device_id.index, curr_device) self.assertEqual(mesh_2d.get_group(1).bound_device_id.index, curr_device) @with_comms() def test_get_group_and_get_all_groups(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") ) tp_mesh = mesh_2d["tp"] dp_mesh = mesh_2d["dp"] self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp")) self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp")) self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group()) self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group()) groups = mesh_2d.get_all_groups() self.assertEqual(len(groups), 2) self.assertTrue(tp_mesh.get_group() in groups) self.assertTrue(dp_mesh.get_group() in groups) @with_comms def test_get_local_rank_raises_exception(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") ) with self.assertRaisesRegex( RuntimeError, "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", ): mesh_2d.get_local_rank() @with_comms def test_get_local_rank(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") ) self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0)) self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1)) dp_mesh = mesh_2d["dp"] tp_mesh = mesh_2d["tp"] self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp")) self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp")) # Verify flattened mesh local rank correctness. flattened_mesh = mesh_2d["dp", "tp"]._flatten() self.assertEqual(flattened_mesh.get_local_rank(), self.rank) @with_comms def test_device_mesh_2d(self): mesh_tensor = torch.arange(4).reshape(2, 2) # construct a device mesh for self.device_type mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups dim_to_subgroups = mesh.get_all_groups() expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] for dim, dim_group in enumerate(dim_to_subgroups): self.assertTrue(dim < 2) dim_ranks = expected_ranks_by_dim[dim] dim_group_size = get_world_size(dim_group) self.assertIsInstance(dim_group, ProcessGroup) self.assertEqual(dim_group_size, 2) global_ranks = [ get_global_rank(dim_group, i) for i in range(dim_group_size) ] current_rank_expected_group_ranks = ( dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1] ) self.assertEqual(global_ranks, current_rank_expected_group_ranks) @with_comms def test_device_mesh_init_backend(self): mesh = DeviceMesh( self.device_type, torch.arange(10), _init_backend=False, _rank=5 ) with self.assertRaisesRegex(RuntimeError, "process groups not initialized!"): mesh.get_group() # coordinates should always been populated when init_backend is False, as whenever # we call init_backend we should make sure the default pg already created self.assertEqual(mesh.get_coordinate(), [5]) def test_fake_pg_device_mesh(self): fake_store = FakeStore() init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size) device_type = ( torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" ) mesh = DeviceMesh(device_type, torch.arange(self.world_size)) local_tensor = torch.randn(2, 8) global_tensor = funcol.all_gather_tensor( local_tensor, gather_dim=0, group=(mesh, 0) ).wait() self.assertEqual(global_tensor.shape, (self.world_size * 2, 8)) @with_comms def test_from_group_with_global_pg(self): # Simple test: check `from_group` from a mesh pg vs. directly # initializing via `init_device_mesh` ref_global_mesh = init_device_mesh(self.device_type, (self.world_size,)) mesh_pg = ref_global_mesh.get_group() global_mesh = DeviceMesh.from_group(mesh_pg, self.device_type) self.assertEqual(ref_global_mesh, global_mesh) self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names) self.assertEqual( ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim ) # Check when `mesh` is passed as well global_mesh = DeviceMesh.from_group( mesh_pg, self.device_type, mesh=torch.arange(self.world_size) ) self.assertEqual(ref_global_mesh, global_mesh) self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names) self.assertEqual( ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim ) @with_comms def test_from_group_with_invalid_mesh(self): global_pg = _get_default_group() global_pg_size = global_pg.size() assert global_pg_size == 4, "Test assumes global world size of 4" invalid_mesh = [[0, 1], [2, 3]] # 2D mesh when we need 1D regex = r"Invalid mesh \[\[0, 1\], \[2, 3\]\] for ProcessGroup with ranks \[0, 1, 2, 3\]" with self.assertRaisesRegex(ValueError, regex): DeviceMesh.from_group( global_pg, device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1") ) device_mesh = init_device_mesh(self.device_type, (2, 2)) groups = device_mesh.get_all_groups() invalid_mesh = (0, 1, 2, 3) # 1D mesh when we need 2D regex = r"Expects mesh with ndim equal to number of ProcessGroups but got mesh \[0, 1, 2, 3\] and 2 ProcessGroups" with self.assertRaisesRegex(ValueError, regex): DeviceMesh.from_group( groups, self.device_type, invalid_mesh, mesh_dim_names=("dim0", "dim1") ) def test_raises_invalid_device_type(self): with self.assertRaisesRegex( RuntimeError, "Device type with index is not supported", ): # test init_device_mesh with an invalid device type that contains a GPU index mesh_shape = (2, self.world_size // 2) init_device_mesh( f"{device_type}:0", mesh_shape=mesh_shape, mesh_dim_names=("dp", "tp") ) @with_comms def test_get_root_mesh_multiple_independent_meshes(self): # regression test for issue #163330 # when creating multiple independent device meshes and slicing them, # get_root_mesh should return the correct parent mesh for each submesh mesh1 = init_device_mesh( self.device_type, (2, 2), mesh_dim_names=("dp", "tp"), ) mesh1_dp = mesh1["dp"] mesh1_tp = mesh1["tp"] mesh2 = init_device_mesh( self.device_type, (2, 2), mesh_dim_names=("dim1", "dim2"), ) mesh2_dim1 = mesh2["dim1"] mesh2_dim2 = mesh2["dim2"] self.assertEqual(_mesh_resources.get_root_mesh(mesh1_dp), mesh1) self.assertEqual(_mesh_resources.get_root_mesh(mesh1_tp), mesh1) self.assertEqual(_mesh_resources.get_root_mesh(mesh2_dim1), mesh2) self.assertEqual(_mesh_resources.get_root_mesh(mesh2_dim2), mesh2) self.assertNotEqual(_mesh_resources.get_root_mesh(mesh1_dp), mesh2) self.assertNotEqual(_mesh_resources.get_root_mesh(mesh1_tp), mesh2) class DeviceMeshTestNDim(DTensorTestBase): @property def world_size(self): return 8 @with_comms def test_device_mesh_nd(self): # construct a device mesh for self.device_type mesh_tensor = torch.arange(8).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): self.assertTrue(dim < mesh_tensor.ndim) dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2) dim_group_size = get_world_size(dim_group) self.assertIsInstance(dim_group, ProcessGroup) self.assertEqual(dim_group_size, 2) global_ranks = [ get_global_rank(dim_group, i) for i in range(dim_group_size) ] for ranks in dim_ranks: if self.rank in ranks: self.assertEqual(global_ranks, ranks.tolist()) @with_comms def test_device_mesh_hash(self): mesh_tensor_2d = torch.arange(8).reshape(4, 2) mesh = DeviceMesh(self.device_type, mesh_tensor_2d) mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d) self.assertEqual(hash(mesh), hash(mesh2)) mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2) mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d) self.assertNotEqual(hash(mesh), hash(mesh3)) self.assertNotEqual(hash(mesh2), hash(mesh3)) @with_comms def test_get_local_rank_3d(self): """ If we have a 3D mesh and we want to apply dp, pp, tp to it, mesh_dim_names = ["dp", "pp", "tp"], and the mesh tensor would be: mesh_3d_tensor = [ [ [0, 1], [2, 3], ], [ [4, 5], [6, 7], ] ] """ mesh_shape = (2, 2, 2) mesh_3d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=("dp", "pp", "tp") ) # tp_rank_0: [0, 2, 4, 6], tp_rank_1: [1, 3, 5, 7] tp_rank = mesh_3d.get_local_rank("tp") expected_tp_rank = self.rank % 2 self.assertEqual(tp_rank, expected_tp_rank) # pp_rank_0: [0, 1, 4, 5], pp_rank_1: [2, 3, 6, 7] pp_rank = mesh_3d.get_local_rank("pp") expected_pp_rank = 0 if self.rank % 4 <= 1 else 1 self.assertEqual(pp_rank, expected_pp_rank) # dp_rank_0: [0, 1, 2, 3], dp_rank_1: [4, 5, 6, 7] dp_rank = mesh_3d.get_local_rank("dp") expected_dp_rank = self.rank // 4 self.assertEqual(dp_rank, expected_dp_rank) @with_comms def test_device_mesh_parent_child_hash(self): mesh_2d = init_device_mesh( self.device_type, (2, self.world_size // 2), mesh_dim_names=("DP", "TP") ) mesh_group_1 = torch.arange(0, self.world_size // 2) mesh_group_2 = torch.arange(self.world_size // 2, self.world_size) ep_mesh_1 = DeviceMesh(self.device_type, mesh_group_1) ep_mesh_2 = DeviceMesh(self.device_type, mesh_group_2) ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2 # ep_mesh is considered different from mesh_2d["TP"] self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list) self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout) self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape) self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type) self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names) self.assertEqual(mesh_2d["TP"]._thread_id, ep_mesh._thread_id) self.assertNotEqual(hash(mesh_2d["TP"]), hash(ep_mesh)) self.assertNotEqual(mesh_2d["TP"], ep_mesh) another_mesh_1 = DeviceMesh(self.device_type, mesh_group_1) another_mesh_2 = DeviceMesh(self.device_type, mesh_group_2) another_mesh = ( another_mesh_1 if self.rank < self.world_size // 2 else another_mesh_2 ) # another_mesh is considered the same as ep_mesh self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list) self.assertEqual(ep_mesh._layout, another_mesh._layout) self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape) self.assertEqual(ep_mesh.device_type, another_mesh.device_type) self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names) self.assertEqual(ep_mesh._thread_id, another_mesh._thread_id) self.assertEqual(hash(ep_mesh), hash(another_mesh)) self.assertEqual(ep_mesh, another_mesh) @with_comms def test_from_group_with_mesh_shape_3d(self): """Tests ``from_group`` when passing ``mesh_shape`` as 3D.""" # Consider the following 3D scenario and we need to create the 2D HSDP mesh from it. # - (2, 2, 2) ("dp_replicate", "dp_shard", "tp") mesh mesh_shape = (2, 2, 2) mesh_dim_names = ("dp_replicate", "dp_shard", "tp") ref_mesh = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) dp_shard_group = ref_mesh["dp_shard"].get_group() dp_replicate_group = ref_mesh["dp_replicate"].get_group() dp_mesh = DeviceMesh.from_group( [dp_replicate_group, dp_shard_group], self.device_type, mesh=ref_mesh.mesh[:, :, ref_mesh.get_local_rank(mesh_dim="tp")], mesh_dim_names=("dp_replicate", "dp_shard"), ) ref_mesh_dp_dim_group_names = ref_mesh._dim_group_names[:2] self.assertEqual(ref_mesh_dp_dim_group_names, dp_mesh._dim_group_names[:2]) # Cannot check directly for mesh equality since parent meshes are not # the same since the ref's parent mesh is 3D self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh) self.assertEqual( dp_mesh["dp_replicate"]._dim_group_names, ref_mesh["dp_replicate"]._dim_group_names, ) self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh) self.assertEqual( dp_mesh["dp_shard"]._dim_group_names, ref_mesh["dp_shard"]._dim_group_names, ) @with_comms() def test_from_group_with_mesh_shape_2d(self): """Tests ``from_group`` when passing ``mesh_shape`` as 2D.""" # Consider the following scenario where the process group has been created, # but we need to create the 2D HSDP mesh from it later in the program. mesh_shape = (2, 4) mesh_dim_names = ("dp_replicate", "dp_shard") ref_mesh = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7)) # and assign the correct shard group to each rank shard_rank_lists = ( list(range(0, self.world_size // 2)), list(range(self.world_size // 2, self.world_size)), ) shard_groups = ( new_group(shard_rank_lists[0]), new_group(shard_rank_lists[1]), ) current_shard_group = ( shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1] ) # Create replicate groups (for example, (0, 4), (1, 5), (2, 6), (3, 7)) # and assign the correct replicate group to each rank current_replicate_group = None shard_factor = len(shard_rank_lists[0]) for i in range(self.world_size // 2): replicate_group_ranks = list(range(i, self.world_size, shard_factor)) replicate_group = new_group(replicate_group_ranks) if self.rank in replicate_group_ranks: current_replicate_group = replicate_group dp_mesh = DeviceMesh.from_group( [not_none(current_replicate_group), current_shard_group], self.device_type, mesh=ref_mesh.mesh, mesh_dim_names=("dp_replicate", "dp_shard"), ) for mesh_dim_group, ref_mesh_dim_group in zip( dp_mesh.get_all_groups(), ref_mesh.get_all_groups() ): mesh_dim_group_ranks = dist.get_process_group_ranks(mesh_dim_group) ref_mesh_dim_group_ranks = dist.get_process_group_ranks(ref_mesh_dim_group) self.assertEqual(mesh_dim_group_ranks, ref_mesh_dim_group_ranks) # check both the 2d mesh and the submeshes are exactly the same. self.assertEqual(dp_mesh, ref_mesh) self.assertEqual(dp_mesh["dp_replicate"], ref_mesh["dp_replicate"]) self.assertEqual(dp_mesh["dp_shard"], ref_mesh["dp_shard"]) class InitDeviceMeshTest(DTensorTestBase): @property def world_size(self): return 8 @with_comms def test_init_device_mesh(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") ref_mesh = DeviceMesh( self.device_type, torch.arange(8).view(mesh_shape), mesh_dim_names=mesh_dim_names, ) # test init_device_mesh with mesh_dim_names mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) self.assertEqual(mesh_2d, ref_mesh) self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names) @with_comms def test_raises_duplicate_mesh_dim_names(self): with self.assertRaisesRegex( RuntimeError, "Each mesh_dim_name must be unique.", ): init_device_mesh( self.device_type, (2, 4), mesh_dim_names=["dp", "dp"], ) @with_comms def test_raises_mesh_shape_mesh_dim_names_mismatch(self): with self.assertRaisesRegex( RuntimeError, "mesh_shape and mesh_dim_names should have same length!", ): init_device_mesh( self.device_type, (8,), mesh_dim_names=["dp", "tp"], ) def _test_backend_override_argument_dict_with_idx_and_backend(self): opts = FakeProcessGroup.Options() opts.fake_option = 42 mesh = init_device_mesh( self.device_type, (2, 2, 2), mesh_dim_names=("dp", "tp", "cp"), backend_override={0: "fake", 2: ("fake", opts)}, ) def get_opts(mesh: DeviceMesh, dim_idx: int) -> C10dBackend.Options: return ( mesh.get_group(dim_idx) ._get_backend(torch.device(f"{self.device_type}:{self.rank}")) .options ) # Fake pg only have BackendType as BackendType::CUSTOM. self.assertEqual(mesh.get_group(0)._get_backend_name(), "custom") self.assertNotEqual(mesh.get_group(1)._get_backend_name(), "custom") self.assertEqual(mesh.get_group(2)._get_backend_name(), "custom") self.assertIsNone(get_opts(mesh, 0)) self.assertEqual(get_opts(mesh, 2).fake_option, 42) dp_tp_mesh = mesh["dp", "tp"]._flatten() dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override="fake") tp_cp_mesh = mesh["tp", "cp"]._flatten(backend_override=("fake", opts)) self.assertNotEqual(dp_tp_mesh.get_group(0)._get_backend_name(), "custom") self.assertEqual(dp_cp_mesh.get_group(0)._get_backend_name(), "custom") self.assertEqual(tp_cp_mesh.get_group(0)._get_backend_name(), "custom") self.assertIsNone(get_opts(dp_cp_mesh, 0)) self.assertEqual(get_opts(tp_cp_mesh, 0).fake_option, 42) @with_comms def test_backend_override_argument_dict_with_idx_and_backend_lazy(self): self._test_backend_override_argument_dict_with_idx_and_backend() @with_comms(eager_init=True) def test_backend_override_argument_dict_with_idx_and_backend_eager(self): self._test_backend_override_argument_dict_with_idx_and_backend() @with_comms(backend="fake") def test_backend_override_argument_dict_with_name_and_options(self): opts = FakeProcessGroup.Options() opts.fake_option = 42 mesh = init_device_mesh( self.device_type, (2, 2, 2), mesh_dim_names=("dp", "tp", "cp"), backend_override={"tp": opts}, ) def get_opts(mesh: DeviceMesh, dim_idx: int) -> C10dBackend.Options: return ( mesh.get_group(dim_idx) ._get_backend(torch.device(f"{self.device_type}:{self.rank}")) .options ) self.assertIsNone(get_opts(mesh, 0)) self.assertEqual(get_opts(mesh, 1).fake_option, 42) self.assertIsNone(get_opts(mesh, 2)) dp_tp_mesh = mesh["dp", "tp"]._flatten() dp_cp_mesh = mesh["dp", "cp"]._flatten(backend_override=opts) self.assertIsNone(get_opts(dp_tp_mesh, 0)) self.assertEqual(get_opts(dp_cp_mesh, 0).fake_option, 42) @with_comms def test_backend_override_argument_errors(self): with self.assertRaisesRegex( RuntimeError, "Found redundant dim index 0 and name dp in backend_override", ): init_device_mesh( self.device_type, (2, 4), mesh_dim_names=("dp", "tp"), backend_override={"dp": "foo", 0: "bar"}, ) with self.assertRaisesRegex( RuntimeError, r"Found invalid keys in backend_override: got \['cp'\]", ): init_device_mesh( self.device_type, (2, 4), mesh_dim_names=("dp", "tp"), backend_override={"cp": "foo"}, ) with self.assertRaisesRegex( RuntimeError, r"Found invalid keys in backend_override: got \[42\]", ): init_device_mesh( self.device_type, (2, 4), mesh_dim_names=("dp", "tp"), backend_override={42: "bar"}, ) class TestDeviceMeshGetItem(DTensorTestBase): @property def world_size(self): return 8 @with_comms def test_raises_no_mesh_dim_found(self): with self.assertRaisesRegex( RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names!" ): mesh = init_device_mesh(self.device_type, (2, 4)) mesh["DP"] @with_comms def test_raises_invalid_mesh_dim_name(self): child_mesh_dim_name = ("PP",) with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): mesh_dim_names = ("DP", "TP") mesh = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=mesh_dim_names, ) mesh[child_mesh_dim_name] @with_comms def test_get_item_2d(self): mesh_shape = (2, 4) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) pg_ranks_by_dim_name = {} for mesh_dim_name in mesh_dim_names: mesh_dim = mesh_dim_names.index(mesh_dim_name) pg_ranks_by_dim_name[mesh_dim_name] = mesh_2d.mesh.swapdims( -1, mesh_dim ).reshape(-1, mesh_2d.mesh.size(mesh_dim)) tp_mesh = mesh_2d["TP"] tp_group_idx = self.rank // 4 self.assertEqual(tp_mesh.mesh, pg_ranks_by_dim_name["TP"][tp_group_idx]) dp_group_idx = self.rank % 4 self.assertEqual(mesh_2d["DP"].mesh, pg_ranks_by_dim_name["DP"][dp_group_idx]) @with_comms def test_get_item_1d(self): mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",)) # Make sure slicing out 1D mesh from a 1D mesh works. dp_mesh = mesh["dp"] self.assertEqual(dp_mesh, mesh) with self.assertRaisesRegex(KeyError, "Invalid mesh_dim_name"): dp_mesh = mesh["dim0"] @with_comms def test_get_item_3d(self): mesh_shape = (2, 2, 2) mesh_dim_names = ("Replicate", "Shard", "TP") mesh_3d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]] tp_group_idx = int(self.rank / 2) self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx]) shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]] shard_group_idx = self.rank % 2 + self.rank // 4 * 2 self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx]) replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]] replicate_group_idx = self.rank % 4 self.assertEqual( mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx] ) # We support both UX for nD slicing. # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"] hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]] hsdp_mesh_2 = mesh_3d["Replicate", "Shard"] hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]] hsdp_group_idx = self.rank % 2 self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx]) self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) # Test slicing out 1D mesh from a sub-2D mesh. shard_mesh = hsdp_mesh_2["Shard"] self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx]) replicate_mesh = hsdp_mesh_2["Replicate"] self.assertEqual( replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx] ) @with_comms def test_cache_and_reuse_submesh_slice_result(self): mesh = init_device_mesh(self.device_type, (2, 4), mesh_dim_names=("dp", "tp")) ref_pg_count = _world.group_count # When we call the "dp" slice second time, it should not create any new pg. # As we are just using the cached result so the pg count should be the same. self.assertEqual(ref_pg_count, _world.group_count) # When we call the "tp" slice, it should not create a new pg, as the "tp" slice would # just reuse the parent mesh pg. mesh["tp"] self.assertEqual(_world.group_count, ref_pg_count) @with_comms def test_get_item_3d_noncontiguous_slicing(self): mesh_shape = (2, 2, 2) mesh_dim_names = ("dp", "pp", "cp") mesh_3d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) # Slice order simply decides which mesh_dim sits on which mesh_dim. # For dp_cp_mesh, cp mesh is the innermost dimension. dp_cp_mesh = mesh_3d["dp", "cp"] expected_mesh_tensor = ( torch.tensor([[0, 1], [4, 5]], dtype=torch.int) if self.rank in (0, 1, 4, 5) else torch.tensor([[2, 3], [6, 7]], dtype=torch.int) ) dp_local_rank = dp_cp_mesh.get_local_rank("dp") self.assertEqual(dp_cp_mesh.mesh, expected_mesh_tensor) cp_mesh = mesh_3d["cp"] # Check on the current dp_local_rank, whether the cp mesh tensor is the same. self.assertEqual(dp_cp_mesh.mesh[dp_local_rank], cp_mesh.mesh) with self.assertRaisesRegex( KeyError, "Invalid mesh_dim_names", ): mesh_3d["cp", "dp"] @with_comms def test_flatten_mesh_1d(self): mesh_shape = (4,) mesh_dim_names = ("default",) mesh_1d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) mesh_1d._flatten() @with_comms def test_flatten_mesh_3d(self): mesh_shape = (2, 2, 2) mesh_dim_names = ("dp", "cp", "tp") mesh_3d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) # Test flatten into an existing mesh_dim_name inside the mesh with self.assertRaisesRegex( ValueError, "already exists for submesh of the DeviceMesh", ): mesh_3d._flatten("dp") # Test flatten contiguous dims dp_cp_mesh = mesh_3d["dp", "cp"] flattened_dp_cp_mesh = dp_cp_mesh._flatten() self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh) self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp") self.assertEqual(flattened_dp_cp_mesh.get_group().group_desc, "mesh_dp_cp") root_mesh = dp_cp_mesh._get_root_mesh() self.assertEqual(root_mesh, mesh_3d) flatten_mesh_layout = root_mesh._flatten_mapping["dp_cp"]._layout self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout) self.assertEqual( flattened_dp_cp_mesh._layout.global_ranks(8), [[0, 2, 4, 6], [1, 3, 5, 7]], ) ref_pg_count = _world.group_count # Calling flatten again should not create a new pg. flattened_dp_cp_mesh_2 = dp_cp_mesh._flatten() self.assertEqual(flattened_dp_cp_mesh, flattened_dp_cp_mesh_2) self.assertEqual(ref_pg_count, _world.group_count) # Test flatten non-contiguous dims dp_tp_mesh = mesh_3d["dp", "tp"] flattened_dp_tp_mesh = dp_tp_mesh._flatten() self.assertEqual(dp_tp_mesh.mesh.flatten(), flattened_dp_tp_mesh.mesh) self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp") root_mesh = dp_tp_mesh._get_root_mesh() self.assertEqual(root_mesh, mesh_3d) flatten_mesh_root_layout = root_mesh._flatten_mapping["dp_tp"]._layout self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout) self.assertEqual( flattened_dp_tp_mesh._layout.global_ranks(8), [[0, 1, 4, 5], [2, 3, 6, 7]], ) with self.assertRaisesRegex( NotImplementedError, "Currently, this only allows slicing out a contiguous flattened dim", ): mesh_3d["dp_tp", "cp"] # Test flatten with a flattened mesh_dim_name cp_tp_mesh = mesh_3d["cp", "tp"] cp_tp_mesh._flatten("dummy") self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy") # Test flatten into an existing mesh_dim_name inside the mesh with self.assertRaisesRegex( ValueError, "dp already exists for submesh of the DeviceMesh", ): mesh_3d._flatten("dp") with self.assertRaisesRegex( ValueError, "Flatten mesh with mesh_dim_name dp_tp has been created before", ): mesh_3d["cp", "tp"]._flatten("dp_tp") @with_comms(eager_init=True) def test_flatten_mesh_4d(self): mesh_shape = (2, 2, 2, 1) mesh_dim_names = ("dp_replicate", "dp_shard", "cp", "tp") mesh_4d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) # flatten HSDP and CP into one mesh dp_cp_mesh = mesh_4d[mesh_dim_names[:3]]._flatten("dp_cp") # check flattened mesh integrity self.assertEqual(mesh_4d["dp_cp"].mesh.flatten(), dp_cp_mesh.mesh) # check flattened mesh dim names is correct self.assertEqual(dp_cp_mesh.mesh_dim_names, ("dp_cp",)) # check flattened mesh dependency self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_4d) @with_comms def test_unflatten_mesh_2d(self): mesh_shape = (4, 2) mesh_dim_names = ("dp", "tp") mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) unflatten_mesh = mesh_2d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate")) self.assertEqual( unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "tp"] ) self.assertEqual(mesh_2d["tp"].mesh, unflatten_mesh["tp"].mesh) self.assertEqual(mesh_2d["tp"].get_group(), unflatten_mesh["tp"].get_group()) # Not supporting slicing out unflatten dim name from root mesh. with self.assertRaises(KeyError): self.assertEqual(mesh_2d["dp_shard"].mesh, unflatten_mesh["dp_shard"].mesh) @with_comms def test_unflatten_mesh_3d(self): # Test unflatten from a dummy world mesh, which is the case we need for Expert Parallelism(EP). global_mesh = init_device_mesh( self.device_type, (8,), mesh_dim_names=("world",), ) non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp")) ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp")) self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh) self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh) mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp")) unflatten_mesh = mesh_3d._unflatten(0, (2, 2), ("dp_shard", "dp_replicate")) self.assertEqual( unflatten_mesh.mesh_dim_names, ["dp_shard", "dp_replicate", "cp", "tp"] ) self.assertEqual(mesh_3d["tp"].mesh, unflatten_mesh["tp"].mesh) self.assertEqual(mesh_3d["tp"].get_group(), unflatten_mesh["tp"].get_group()) self.assertEqual(mesh_3d["cp"].mesh, unflatten_mesh["cp"].mesh) self.assertEqual(mesh_3d["cp"].get_group(), unflatten_mesh["cp"].get_group()) # Test unflatten with backend override set. if not _NCCL_AVAILABLE: return opts = dist.ProcessGroupNCCL.Options() opts._timeout = timedelta(seconds=30) mesh_2d = global_mesh._unflatten( 0, (1, 8), ("pp", "spmd"), backend_override={"pp": "fake", "spmd": ("nccl", opts)}, ) opts = dist.ProcessGroupNCCL.Options() opts._timeout = timedelta(seconds=60) mesh_4d = mesh_2d._unflatten( 1, (2, 2, 2), ("dp", "cp", "tp"), backend_override={"dp": "nccl", "cp": "nccl", "tp": ("nccl", opts)}, ) self.assertEqual(mesh_4d["pp"].get_group()._get_backend_name(), "custom") spmd_pg = mesh_2d["spmd"].get_group() self.assertEqual(spmd_pg._get_backend_name(), "nccl") w = spmd_pg.allreduce(torch.rand(10).cuda(self.rank)) self.assertTrue( spmd_pg._get_backend( torch.device(f"cuda:{self.rank}") )._verify_work_timeout(w, timedelta(seconds=30)) ) w.wait() tp_pg = mesh_4d["tp"].get_group() self.assertEqual(tp_pg._get_backend_name(), "nccl") w = tp_pg.allreduce(torch.rand(10).cuda(self.rank)) self.assertTrue( tp_pg._get_backend(torch.device(f"cuda:{self.rank}"))._verify_work_timeout( w, timedelta(seconds=60) ) ) w.wait() @with_comms def test_reconstruct_mesh_with_flatten_dim(self): mesh_3d = init_device_mesh( self.device_type, (2, 2, 2), mesh_dim_names=("replicate", "shard", "cp") ) shard_cp_mesh = mesh_3d["shard", "cp"]._flatten() hsdp_mesh = mesh_3d["replicate", "shard_cp"] expected_mesh_tensor = torch.tensor( [[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.int ) self.assertEqual(hsdp_mesh.mesh, expected_mesh_tensor) self.assertEqual(shard_cp_mesh.get_group(), mesh_3d["shard_cp"].get_group()) self.assertEqual( shard_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="shard_cp") ) mesh_3d = init_device_mesh( self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp") ) dp_cp_mesh = mesh_3d["dp", "cp"]._flatten() spmd_mesh = mesh_3d["dp_cp", "tp"] expected_mesh_tensor = torch.tensor( [[0, 1], [2, 3], [4, 5], [6, 7]], dtype=torch.int ) self.assertEqual(spmd_mesh.mesh, expected_mesh_tensor) self.assertEqual(dp_cp_mesh.get_group(), mesh_3d["dp_cp"].get_group()) self.assertEqual(dp_cp_mesh.get_group(), mesh_3d.get_group(mesh_dim="dp_cp")) class TestMeshEnv(DTensorTestBase): @property def world_size(self): return 8 @with_comms def test_get_root_mesh(self): mesh_3d = init_device_mesh( self.device_type, (2, 2, 2), mesh_dim_names=("dp", "cp", "tp"), ) dp_cp_mesh = mesh_3d["dp", "cp"] dp_tp_mesh = mesh_3d["dp", "tp"] cp_tp_mesh = mesh_3d["cp", "tp"] dp_mesh = mesh_3d["dp"] cp_mesh = mesh_3d["cp"] tp_mesh = mesh_3d["tp"] # Test BC case is still working self.assertEqual(_mesh_resources.get_root_mesh(dp_cp_mesh), mesh_3d) self.assertEqual(_mesh_resources.get_root_mesh(dp_tp_mesh), mesh_3d) self.assertEqual(_mesh_resources.get_root_mesh(cp_tp_mesh), mesh_3d) self.assertEqual(_mesh_resources.get_root_mesh(dp_mesh), mesh_3d) self.assertEqual(_mesh_resources.get_root_mesh(cp_mesh), mesh_3d) self.assertEqual(_mesh_resources.get_root_mesh(tp_mesh), mesh_3d) self.assertEqual(dp_cp_mesh._get_root_mesh(), mesh_3d) self.assertEqual(dp_tp_mesh._get_root_mesh(), mesh_3d) self.assertEqual(cp_tp_mesh._get_root_mesh(), mesh_3d) self.assertEqual(dp_mesh._get_root_mesh(), mesh_3d) self.assertEqual(cp_mesh._get_root_mesh(), mesh_3d) self.assertEqual(tp_mesh._get_root_mesh(), mesh_3d) @with_comms def test_get_root_mesh_dim_exist(self): mesh_shape = (2, self.world_size // 2) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) self.assertEqual(mesh_2d["DP"]._get_root_mesh_dim(), 0) self.assertEqual(mesh_2d["TP"]._get_root_mesh_dim(), 1) @with_comms def test_get_root_mesh_dim_not_exist(self): mesh_shape = (self.world_size,) mesh = init_device_mesh(self.device_type, mesh_shape) self.assertEqual(mesh._get_root_mesh_dim(), None) @with_comms def test_get_mesh_dim_by_name(self): mesh_shape = (2, self.world_size // 2) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) self.assertEqual(mesh_2d._get_mesh_dim_by_name("DP"), 0) self.assertEqual(mesh_2d._get_mesh_dim_by_name("TP"), 1) @with_comms def test_get_all_submeshes(self): mesh_2d = init_device_mesh( self.device_type, (2, 4), mesh_dim_names=("replicate", "shard"), ) all_submeshes = mesh_2d._get_all_submeshes("replicate") self.assertEqual(len(all_submeshes), 4) self.assertEqual( all(submesh.mesh.numel() == 2 for submesh in all_submeshes), True ) @with_comms def test_mesh_slice_fake_tensor_mode(self): mesh_shape = (2, self.world_size // 2) mesh_dim_names = ("DP", "TP") mesh_2d = init_device_mesh( self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names ) with FakeTensorMode(): mesh_2d["DP"] mesh_2d["TP"] mesh_2d["DP", "TP"] class DeviceMeshCollectiveTest(DTensorTestBase): @property def world_size(self): return 8 @with_comms def test_broadcast_1d(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank mesh_broadcast(local_tensor, mesh, mesh_dim=0) self.assertEqual(local_tensor, torch.zeros(3, 3)) @with_comms def test_scatter_1d(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) scatter_tensor_shape = [3, 3, 3] for scatter_dim in range(len(scatter_tensor_shape)): shard_placement = Shard(scatter_dim) scatter_tensor_shape[scatter_dim] *= self.world_size # make the random seed same across rank torch.manual_seed(0) global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type) splitted_list, _ = shard_placement._split_tensor( global_tensor, mesh.size(), with_padding=True, contiguous=True ) recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()]) # scatter on dim > 0 would generate non-contiguous tensor, verify that works mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0) self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()]) @with_comms def test_scatter_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() tensor_to_split = torch.randn( device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type ) for shard_dim in range(tensor_to_split.ndim): shard_placement = Shard(shard_dim) tensor_to_scatter = tensor_to_split.clone() tensor_splitted_list = list( torch.chunk(tensor_to_split, self.world_size, dim=shard_dim) ) for _ in range(self.world_size - len(tensor_splitted_list)): tensor_splitted_list.append(torch.tensor([], device=self.device_type)) padded_tensor_list, pad_sizes = shard_placement._split_tensor( tensor_to_scatter, device_mesh.size(), with_padding=True, contiguous=True, ) scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0) if pad_sizes[my_rank] != 0: scattered_tensor = unpad_tensor( scattered_tensor, shard_dim, pad_sizes[my_rank] ) if scattered_tensor.numel() == 0: # We need to check numel() instead of size if a tensor is ([]) after unpadding, # since the size could be ([0, 8]) after unpadding. self.assertEqual( scattered_tensor.numel(), tensor_splitted_list[my_rank].numel() ) else: self.assertEqual( scattered_tensor.size(), tensor_splitted_list[my_rank].size() ) self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank]) @with_comms def test_all_gather_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() tensor_to_split = torch.ones( device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type, ) for shard_dim in range(tensor_to_split.ndim): shard_placement = Shard(shard_dim) tensor_padded_list, pad_sizes = shard_placement._split_tensor( tensor_to_split, device_mesh.size(), with_padding=True, contiguous=True, ) local_tensor = tensor_padded_list[my_rank] big_tensor = funcol.all_gather_tensor( local_tensor, gather_dim=shard_dim, group=(device_mesh, 0) ) big_tensor_chunks = list( torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim) ) unpadded_list = [ ( unpad_tensor(big_tensor, shard_dim, pad_sizes[i]) if pad_sizes[i] > 0 else big_tensor ) for i, big_tensor in enumerate(big_tensor_chunks) ] all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim) self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size()) self.assertEqual(all_gathered_tensor, tensor_to_split) @with_comms def test_reduce_scatter_contiguous(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() # Init the tensor step = self.world_size * 2 total_elem = step**2 tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type) tensor = tensor * (my_rank + 1) # Get non-contiguous tensor by slicing tensor_to_reduce = tensor[::2, :2] tensor_contiguous = tensor_to_reduce.clone().contiguous() # Partial to Shard to trigger reduce_scatter tensor_to_reduce = DTensor.from_local( tensor_to_reduce, device_mesh, [_Partial()] ) tensor_contiguous = DTensor.from_local( tensor_contiguous, device_mesh, [_Partial()] ) new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)]) new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)]) # The output for contiguous and non-contiguous tensors of the same value # should return the same reducescatter value. new_tensor_local = new_tensor._local_tensor new_tensor_contiguous_local = new_tensor_contiguous._local_tensor self.assertEqual(new_tensor_local, new_tensor_contiguous_local) self.assertEqual(list(new_tensor_local.size()), [1, 2]) # Check the reduce numerical value sum_base = (1 + self.world_size) * self.world_size / 2 first_elem = my_rank * sum_base * step * 2 expected_tensor = torch.tensor( [[first_elem, first_elem + sum_base]], dtype=new_tensor_local.dtype, device=self.device_type, ) self.assertEqual(new_tensor_local, expected_tensor) @with_comms def test_reduce_scatter_uneven(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) my_rank = device_mesh.get_rank() tensor_to_split = ( torch.ones( device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type, ) * self.rank ) for shard_dim in range(tensor_to_split.ndim): shard_placement = Shard(shard_dim) tensor_to_scatter = tensor_to_split.clone() tensor_splitted_list = list( torch.chunk(tensor_to_split, self.world_size, dim=shard_dim) ) for _ in range(self.world_size - len(tensor_splitted_list)): tensor_splitted_list.append(torch.tensor([], device=self.device_type)) padded_tensor_list, pad_sizes = shard_placement._split_tensor( tensor_to_scatter, device_mesh.size(), with_padding=True, contiguous=True, ) tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim) res_num = ((0 + self.world_size - 1) * self.world_size) / 2 scattered_tensor = funcol.reduce_scatter_tensor( tensor_to_reduce, reduceOp="sum", scatter_dim=shard_dim, group=(device_mesh, 0), ) # unpad scattered_tensor if pad_sizes[my_rank] > 0: scattered_tensor = unpad_tensor( scattered_tensor, shard_dim, pad_sizes[my_rank] ) if scattered_tensor.numel() == 0: # We need to check numel() instead of size if a tensor is ([]) after unpadding, # since the size could be ([0, 8]) after unpadding. self.assertEqual( scattered_tensor.numel(), tensor_splitted_list[my_rank].numel() ) else: self.assertEqual( scattered_tensor.size(), tensor_splitted_list[my_rank].size() ) self.assertEqual( scattered_tensor, torch.ones_like(tensor_splitted_list[my_rank]) * res_num, ) @with_comms def test_broadcast_nd(self): mesh_tensor = torch.arange(8).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank # check all dim groups dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ get_global_rank(dim_group, i) for i in range(dim_group_size) ] cloned_local_tensor = local_tensor.clone() mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim) res_num = global_ranks[0] self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) @with_comms def test_scatter_nd(self): mesh_tensor = torch.arange(8).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) # check all dim groups dim_to_subgroups = mesh.get_all_groups() for dim, dim_group in enumerate(dim_to_subgroups): dim_group_size = get_world_size(dim_group) global_ranks = [ get_global_rank(dim_group, i) for i in range(dim_group_size) ] scattered_tensors = [ torch.ones(3, 3, device=self.device_type) * global_rank for global_rank in global_ranks ] received_tensor = torch.empty_like( scattered_tensors[mesh.get_coordinate()[dim]] ) mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim) self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) class CuTeLayoutTest(TestCase): def test_coalesce(self): # ((3,2),(2,1)) -> (6,1) l = _Layout((3, 2), (2, 1)) l = l.coalesce() self.assertEqual(list(l.sizes_and_strides), [(6, 1)]) # ((2,12),(3,4),(4,1)) -> (24,1) l = _Layout((2, 3, 4), (12, 4, 1)) l = l.coalesce() self.assertEqual(list(l.sizes_and_strides), [(24, 1)]) def test_coalesce_non_coalescible(self): # ((3,4),(2,1)) stays as-is (4 ≠ 2*1) l = _Layout((3, 2), (4, 1)) l = l.coalesce() self.assertEqual(list(l.sizes_and_strides), [(3, 4), (2, 1)]) def test_complement_n_group_layout(self): # complement((4,2), 8) = (2,1); together form (8,1) pg_layout = _Layout( (4,), (2,), ) outer = pg_layout.complement(world_size=8) self.assertEqual(list(outer.sizes_and_strides), [(2, 1)]) self.assertEqual( pg_layout.all_ranks_from_zero(), [0, 2, 4, 6], ) groups = [ [o + i for i in pg_layout.all_ranks_from_zero()] for o in outer.all_ranks_from_zero() ] self.assertEqual( groups, [ [0, 2, 4, 6], [1, 3, 5, 7], ], ) self.assertEqual( pg_layout.global_ranks(8), [ [0, 2, 4, 6], [1, 3, 5, 7], ], ) # complement((4,2), 16) = ((2,8), (2,1)); together form (16,1) outer = pg_layout.complement(world_size=16) self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)]) self.assertEqual( outer.all_ranks_from_zero(), [0, 1, 8, 9], ) self.assertEqual( pg_layout.global_ranks(16), [ [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15], ], ) # Complement ((2,4), (2,1)) under world_size=16 → complement ((2,8), (2,2)) pg_layout = _Layout((2, 2), (4, 1)) self.assertEqual( pg_layout.all_ranks_from_zero(), [0, 1, 4, 5], ) outer = pg_layout.complement(world_size=16) self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 2)]) self.assertEqual( outer.all_ranks_from_zero(), [0, 2, 8, 10], ) self.assertEqual( pg_layout.global_ranks(16), [ [0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15], ], ) # Test layout_to_global_ranks and layout_to_all_ranks_from_zero pg_layout = _Layout((2, 2), (4, 2)) self.assertEqual( pg_layout.all_ranks_from_zero(), [0, 2, 4, 6], ) self.assertEqual( pg_layout.global_ranks(16), [ [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15], ], ) outer = pg_layout.complement(world_size=16) self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)]) # Test when stride is not monotonically decreasing, the complement layout # is same as the one sorted its stride. pg_layout_r = _Layout((2, 2), (2, 4)) outer = pg_layout_r.complement(world_size=16) self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)]) self.assertEqual( pg_layout_r.global_ranks(16), [ [0, 4, 2, 6], [1, 5, 3, 7], [8, 12, 10, 14], [9, 13, 11, 15], ], ) # Test just all_ranks_from_zero and global_ranks. pg_layout = _Layout((4,), (2,)) self.assertEqual( pg_layout.all_ranks_from_zero(), [0, 2, 4, 6], ) self.assertEqual( pg_layout.global_ranks(16), [ [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15], ], ) def test_composition(self): # self = ((4,2), (2,1)), l = (2,1) → self o l = (2,1) orig_l = _Layout((4, 2), (2, 1)) right_l = _Layout((2,), (1,)) composed_layout = orig_l.composition(right_l) self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 1)]) self.assertEqual( composed_layout.global_ranks(8), [ [0, 1], [2, 3], [4, 5], [6, 7], ], ) # self = (4,2), l = (2,1) → self o l = (2,2) orig_l = _Layout((4,), (2,)) right_l = _Layout((2,), (1,)) composed_layout = orig_l.composition(right_l) self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 2)]) self.assertEqual( composed_layout.global_ranks(8), [ [0, 2], [1, 3], [4, 6], [5, 7], ], ) # self = (4,2), l = ((2,2), (2,1)) → self o l = ((2,4), (2,2)) # This is to mimic the un-flatten from a 2D mesh to a 1D mesh. right_l = _Layout((2, 2), (2, 1)) composed_layout = orig_l.composition(right_l) self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 4), (2, 2)]) self.assertEqual( composed_layout[0].global_ranks(8), [ [0, 4], [1, 5], [2, 6], [3, 7], ], ) self.assertEqual( composed_layout[1].global_ranks(8), [ [0, 2], [1, 3], [4, 6], [5, 7], ], ) # Error case. orig_l = _Layout((4, 2), (4, 1)) with self.assertRaises( AssertionError, ): right_l = _Layout((2,), (3,)) orig_l.composition(right_l) def test_check_non_overlap(self): """Test the check_non_overlap method for various layout configurations.""" # Test 1: Valid layout - no overlap # sizes=(2,3), strides=(6,1) - stride 6 > span 3, so no overlap layout1 = _Layout((2, 3), (6, 1)) self.assertTrue(layout1.check_non_overlap()) # Test 2: Invalid layout - overlap due to stride < previous span # sizes=(2,3), strides=(2,1) - stride 2 < span 3, causes overlap layout2 = _Layout((2, 3), (2, 1)) self.assertFalse(layout2.check_non_overlap()) # Test 3: Invalid layout - duplicate strides # sizes=(2,3), strides=(1,1) - same stride, causes overlap layout3 = _Layout((2, 3), (1, 1)) self.assertFalse(layout3.check_non_overlap()) # Test 4: Valid layout - single dimension layout4 = _Layout((4,), (1,)) self.assertTrue(layout4.check_non_overlap()) # Test 5: Valid layout - exact boundary case # sizes=(2,3), strides=(3,1) - stride 3 == span 3, valid layout5 = _Layout((2, 3), (3, 1)) self.assertTrue(layout5.check_non_overlap()) # Test 6: Valid layout - multi-dimensional with proper spacing layout6 = _Layout((2, 2, 2), (8, 4, 1)) self.assertTrue(layout6.check_non_overlap()) # Test 7: Valid layout - stride not ordered layout7 = _Layout((2, 2, 2), (4, 1, 2)) self.assertTrue(layout7.check_non_overlap()) # Test 8: Valid layout - Interleaved but no overlap layout8 = _Layout((3, 2), (2, 3)) self.assertTrue(layout8.check_non_overlap()) def test_remap_to_tensor(self): """Test the remap_to_tensor method for various scenarios.""" # Test 1: Consecutive ranks, full world - should return logical groups directly original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int) layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2 result1 = layout1.remap_to_tensor(original_mesh) expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int) self.assertEqual(result1, expected1) # Test 2: Non-consecutive ranks - should map to actual ranks original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int) layout2 = _Layout((2, 2), (2, 1)) result2 = layout2.remap_to_tensor(original_mesh) expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int) self.assertEqual(result2, expected2) # Test 4: 1D layout with consecutive ranks original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int) layout4 = _Layout((4,), (1,)) result4 = layout4.remap_to_tensor(original_mesh) expected4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.int) self.assertEqual(result4, expected4) # Test 5: Complex strided layout with non-consecutive ranks original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int) layout5 = _Layout((2, 2), (2, 1)) result5 = layout5.remap_to_tensor(original_mesh) expected5 = torch.tensor([[[5, 10], [15, 20]]], dtype=torch.int) self.assertEqual(result5, expected5) # Test 6: Tensor Cute representation of a 2D mesh original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int) layout6 = _Layout((2, 2), (1, 2)) # column-major style result6 = layout6.remap_to_tensor(original_mesh) expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int) self.assertEqual(result6, expected6) # Test 7: Layout with different stride pattern original_mesh = torch.tensor([0, 2, 1, 4], dtype=torch.int) layout7 = _Layout((2, 2), (1, 2)) # column-major style result7 = layout7.remap_to_tensor(original_mesh) expected7 = torch.tensor([[[0, 1], [2, 4]]], dtype=torch.int) self.assertEqual(result7, expected7) if __name__ == "__main__": run_tests()