[dtensor] delete the old unused mesh_alltoall (#124879)

As titled: since we now have a dedicated comm op, this helper is no longer needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124879
Approved by: https://github.com/XilunWu, https://github.com/wz337
ghstack dependencies: #124871, #124872
Wanchao Liang 2024-04-29 21:34:45 -07:00 committed by PyTorch MergeBot
parent 00df0d3e94
commit 04a241947a
2 changed files with 0 additions and 109 deletions
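
For context before the diffs: the removed helper wrapped a manual all-to-all over one mesh dimension. The dedicated comm op the commit message refers to is not shown in this diff, so the sketch below only illustrates the same even-split shard exchange expressed with the stock torch.distributed.all_to_all_single call on the mesh dimension's process group; the function name shard_exchange and its layout assumptions are illustrative, not part of PyTorch.

# Hypothetical sketch (not part of this PR): exchange equal-sized shards across one
# mesh dimension with the standard all_to_all_single collective, the kind of
# dedicated comm op that makes a hand-rolled mesh_all_to_all helper redundant.
# Assumes torch.distributed is already initialized.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def shard_exchange(
    local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int = 0
) -> torch.Tensor:
    # local_tensor is laid out as [group_size * k, ...]; contiguous block i of k
    # rows is sent to rank i of the mesh dimension's process group.
    dim_group = mesh.get_group(mesh_dim)
    output = torch.empty_like(local_tensor)
    # With split sizes left as None, all_to_all_single splits dim 0 evenly.
    dist.all_to_all_single(output, local_tensor, group=dim_group)
    return output
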


@@ -6,7 +6,6 @@ import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._collective_utils import (
    mesh_all_to_all,
    mesh_broadcast,
    mesh_scatter,
    unpad_tensor,
@@ -700,70 +699,6 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)

    @with_comms
    def test_all_to_all_1d(self):
        # transpose on a 2D tensor distributed over N nodes:
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        tensor_shape = [3, 3]
        input_tensor_list = [
            torch.ones(*tensor_shape, device=self.device_type)
            * (rank + self.rank * self.world_size)
            for rank in range(self.world_size)
        ]
        expected_tensor_list = [
            torch.ones(tensor_shape, device=self.device_type)
            * (self.rank + rank * self.world_size)  # i.e. transpose
            for rank in range(self.world_size)
        ]

        for scatter_dim in range(len(tensor_shape)):
            output_tensor_list = [
                torch.empty_like(input_tensor_list[idx])
                for idx in range(len(input_tensor_list))
            ]
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0)
            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
            self.assertEqual(output_tensor, expected_tensor)

    @with_comms
    def test_all_to_all_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        tensor_shape = [3, 3, 3]
        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            my_coordinate = mesh.get_coordinate()[dim]
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            input_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (i + self.rank * dim_group_size)
                for i in range(dim_group_size)
            ]
            expected_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (my_coordinate + global_rank * dim_group_size)  # i.e. transpose
                for global_rank in global_ranks
            ]

            for scatter_dim in range(len(tensor_shape)):
                # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim)
                output_tensor_list = [
                    torch.empty_like(input_tensor_list[idx])
                    for idx in range(len(input_tensor_list))
                ]
                # scatter on dim > 0 would generate non-contiguous tensor, verify that works
                mesh_all_to_all(
                    output_tensor_list, input_tensor_list, mesh, mesh_dim=dim
                )
                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
                expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
                self.assertEqual(output_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()


@@ -11,11 +11,9 @@ import torch.distributed._tensor.placement_types as placement_types
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
    _get_group_size_by_name,
    all_to_all,
    broadcast,
    get_global_rank,
    get_rank,
    get_world_size,
    GroupMember,
    ProcessGroup,
    scatter,
@@ -150,48 +148,6 @@ def mesh_broadcast(
    return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)


# TODO: test uneven split on GLOO and NCCL
def mesh_all_to_all(
    output_tensor_list: List[torch.Tensor],
    input_tensor_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)

    work = None
    # no direct dist.all_to_all support on 'gloo' so we manually do scatters
    if mesh.device_type == "cpu":
        logger.warning(
            "ProcessGroupGloo does not support all_to_all, falling back with scatters!"
        )
        # TODO: pull the handle of uneven case in #492
        dim_group_size = get_world_size(dim_group)
        for i in range(dim_group_size):
            # src need to be global rank
            src_for_dim = i
            if dim_group is not GroupMember.WORLD:
                src_for_dim = get_global_rank(dim_group, i)

            work = scatter(
                output_tensor_list[i],
                input_tensor_list if mesh.get_rank() == src_for_dim else [],
                group=dim_group,
                src=src_for_dim,
                async_op=async_op,
            )
    else:
        work = all_to_all(
            output_tensor_list,
            input_tensor_list,
            dim_group,
            async_op=async_op,
        )
    return work


def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    if pad_size == 0:
        return tensor