From 04a241947ae9beecabed84bb36698552b82575f7 Mon Sep 17 00:00:00 2001
From: Wanchao Liang
Date: Mon, 29 Apr 2024 21:34:45 -0700
Subject: [PATCH] [dtensor] delete the old unused mesh_alltoall (#124879)

As titled: now that we have a dedicated comm op, this is not needed anymore.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124879
Approved by: https://github.com/XilunWu, https://github.com/wz337
ghstack dependencies: #124871, #124872
---
 test/distributed/test_device_mesh.py         | 65 -------------------
 .../distributed/_tensor/_collective_utils.py | 44 -------------
 2 files changed, 109 deletions(-)

diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py
index 9c54cfa3125..d04fcf938c4 100644
--- a/test/distributed/test_device_mesh.py
+++ b/test/distributed/test_device_mesh.py
@@ -6,7 +6,6 @@ import torch
 import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor import DTensor
 from torch.distributed._tensor._collective_utils import (
-    mesh_all_to_all,
     mesh_broadcast,
     mesh_scatter,
     unpad_tensor,
@@ -700,70 +699,6 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
             mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
             self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)
 
-    @with_comms
-    def test_all_to_all_1d(self):
-        # transpose on a 2D tensor distributed over N nodes:
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-        tensor_shape = [3, 3]
-        input_tensor_list = [
-            torch.ones(*tensor_shape, device=self.device_type)
-            * (rank + self.rank * self.world_size)
-            for rank in range(self.world_size)
-        ]
-        expected_tensor_list = [
-            torch.ones(tensor_shape, device=self.device_type)
-            * (self.rank + rank * self.world_size)  # i.e. transpose
-            for rank in range(self.world_size)
-        ]
-        for scatter_dim in range(len(tensor_shape)):
-            output_tensor_list = [
-                torch.empty_like(input_tensor_list[idx])
-                for idx in range(len(input_tensor_list))
-            ]
-            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
-            mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0)
-            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
-            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
-
-            self.assertEqual(output_tensor, expected_tensor)
-
-    @with_comms
-    def test_all_to_all_nd(self):
-        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
-        mesh = DeviceMesh(self.device_type, mesh_tensor)
-        tensor_shape = [3, 3, 3]
-        # check all dim groups
-        dim_to_subgroups = mesh.get_group()
-        for dim, dim_group in enumerate(dim_to_subgroups):
-            my_coordinate = mesh.get_coordinate()[dim]
-            dim_group_size = get_world_size(dim_group)
-            global_ranks = [
-                get_global_rank(dim_group, i) for i in range(dim_group_size)
-            ]
-            input_tensor_list = [
-                torch.ones(*tensor_shape, device=self.device_type)
-                * (i + self.rank * dim_group_size)
-                for i in range(dim_group_size)
-            ]
-            expected_tensor_list = [
-                torch.ones(*tensor_shape, device=self.device_type)
-                * (my_coordinate + global_rank * dim_group_size)  # i.e. transpose
-                for global_rank in global_ranks
-            ]
-            for scatter_dim in range(len(tensor_shape)):
-                # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim)
-                output_tensor_list = [
-                    torch.empty_like(input_tensor_list[idx])
-                    for idx in range(len(input_tensor_list))
-                ]
-                # scatter on dim > 0 would generate non-contiguous tensor, verify that works
-                mesh_all_to_all(
-                    output_tensor_list, input_tensor_list, mesh, mesh_dim=dim
-                )
-                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
-                expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
-                self.assertEqual(output_tensor, expected_tensor)
-
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py
index 51c13796255..ce4809d996d 100644
--- a/torch/distributed/_tensor/_collective_utils.py
+++ b/torch/distributed/_tensor/_collective_utils.py
@@ -11,11 +11,9 @@ import torch.distributed._tensor.placement_types as placement_types
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
 from torch.distributed.distributed_c10d import (
     _get_group_size_by_name,
-    all_to_all,
     broadcast,
     get_global_rank,
     get_rank,
-    get_world_size,
     GroupMember,
     ProcessGroup,
     scatter,
@@ -150,48 +148,6 @@ def mesh_broadcast(
     return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)
 
 
-# TODO: test uneven split on GLOO and NCCL
-def mesh_all_to_all(
-    output_tensor_list: List[torch.Tensor],
-    input_tensor_list: List[torch.Tensor],
-    mesh: DeviceMesh,
-    mesh_dim: int = 0,
-    async_op: bool = False,
-) -> Optional[Work]:
-    dim_group = mesh.get_group(mesh_dim)
-    assert isinstance(dim_group, ProcessGroup)
-
-    work = None
-    # no direct dist.all_to_all support on 'gloo' so we manually do scatters
-    if mesh.device_type == "cpu":
-        logger.warning(
-            "ProcessGroupGloo does not support all_to_all, falling back with scatters!"
-        )
-        # TODO: pull the handle of uneven case in #492
-        dim_group_size = get_world_size(dim_group)
-        for i in range(dim_group_size):
-            # src need to be global rank
-            src_for_dim = i
-            if dim_group is not GroupMember.WORLD:
-                src_for_dim = get_global_rank(dim_group, i)
-
-            work = scatter(
-                output_tensor_list[i],
-                input_tensor_list if mesh.get_rank() == src_for_dim else [],
-                group=dim_group,
-                src=src_for_dim,
-                async_op=async_op,
-            )
-    else:
-        work = all_to_all(
-            output_tensor_list,
-            input_tensor_list,
-            dim_group,
-            async_op=async_op,
-        )
-    return work
-
-
 def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
     if pad_size == 0:
         return tensor
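
For context, the deleted `mesh_all_to_all` exchanged equal-sized tensor chunks along one mesh dimension. Below is a minimal sketch of the same exchange written directly against the mesh's dim group with the plain c10d `all_to_all_single` collective; the helper name `all_to_all_on_mesh_dim` and the equal-split assumption are illustrative only, not the dedicated comm op this PR refers to.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def all_to_all_on_mesh_dim(chunks: torch.Tensor, mesh: DeviceMesh, mesh_dim: int = 0) -> torch.Tensor:
    # chunks: (group_size, ...) stacked slices, where chunks[i] is destined for rank i of the dim group.
    dim_group = mesh.get_group(mesh_dim)  # ProcessGroup backing this mesh dimension
    out = torch.empty_like(chunks)        # equal splits, so no split-size lists are needed
    # NCCL supports all_to_all directly; gloo does not, which is why the old helper fell back to scatters.
    dist.all_to_all_single(out, chunks, group=dim_group)
    return out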