[dtensor] delete the old unused mesh_alltoall (#124879)

As titled: now that there is a dedicated comm op, this helper is no longer needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124879
Approved by: https://github.com/XilunWu, https://github.com/wz337
ghstack dependencies: #124871, #124872
This commit is contained in:
parent 00df0d3e94
commit 04a241947a
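The "dedicated comm op" the message refers to is not shown in this excerpt. Purely as an illustration of the direction, here is a minimal sketch under the assumption that an even all-to-all over one mesh dimension can be expressed with the functional collective all_to_all_single (the funcol module imported in the test diff below); demo_all_to_all and its even-split handling are illustrative, not code from this PR.

# Illustrative sketch only -- not necessarily the "dedicated comm op" this PR
# refers to, which is not shown in the excerpt. Assumes an initialized
# process group and a 1-D DeviceMesh; `demo_all_to_all` is a hypothetical name.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def demo_all_to_all(x: torch.Tensor, mesh: DeviceMesh, mesh_dim: int = 0) -> torch.Tensor:
    # Resolve the ProcessGroup backing this mesh dimension, just as the
    # deleted helper did via mesh.get_group(mesh_dim).
    group = mesh.get_group(mesh_dim)
    n = dist.get_world_size(group)
    splits = [x.size(0) // n] * n  # even split along dim 0
    # Functional collective: the returned tensor synchronizes lazily on first use.
    return funcol.all_to_all_single(x, splits, splits, group)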
@@ -6,7 +6,6 @@ import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._collective_utils import (
    mesh_all_to_all,
    mesh_broadcast,
    mesh_scatter,
    unpad_tensor,
@@ -700,70 +699,6 @@ class DeviceMeshCollectiveTest(DTensorTestBase):
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)

    @with_comms
    def test_all_to_all_1d(self):
        # transpose on a 2D tensor distributed over N nodes:
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        tensor_shape = [3, 3]
        input_tensor_list = [
            torch.ones(*tensor_shape, device=self.device_type)
            * (rank + self.rank * self.world_size)
            for rank in range(self.world_size)
        ]
        expected_tensor_list = [
            torch.ones(tensor_shape, device=self.device_type)
            * (self.rank + rank * self.world_size)  # i.e. transpose
            for rank in range(self.world_size)
        ]
        for scatter_dim in range(len(tensor_shape)):
            output_tensor_list = [
                torch.empty_like(input_tensor_list[idx])
                for idx in range(len(input_tensor_list))
            ]
            # scatter on dim > 0 would generate non-contiguous tensor, verify that works
            mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0)
            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)

            self.assertEqual(output_tensor, expected_tensor)

    @with_comms
    def test_all_to_all_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        tensor_shape = [3, 3, 3]
        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            my_coordinate = mesh.get_coordinate()[dim]
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            input_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (i + self.rank * dim_group_size)
                for i in range(dim_group_size)
            ]
            expected_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (my_coordinate + global_rank * dim_group_size)  # i.e. transpose
                for global_rank in global_ranks
            ]
            for scatter_dim in range(len(tensor_shape)):
                # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim)
                output_tensor_list = [
                    torch.empty_like(input_tensor_list[idx])
                    for idx in range(len(input_tensor_list))
                ]
                # scatter on dim > 0 would generate non-contiguous tensor, verify that works
                mesh_all_to_all(
                    output_tensor_list, input_tensor_list, mesh, mesh_dim=dim
                )
                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
                expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
                self.assertEqual(output_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()
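Both deleted tests assert the same "transpose" pattern: on rank r, the chunk destined for peer i carries the value i + r * N, and after the exchange the chunk received from peer i carries r + i * N, so the N x N matrix of chunk values is transposed. A minimal single-process sketch of just that index arithmetic (no process group involved; N is an illustrative stand-in for the dim-group size):

# Single-process sketch of the value pattern the deleted tests checked.
# No collective runs here; this only mirrors the index arithmetic.
N = 4  # stands in for the world size / dim-group size

# inputs[r][i] is the value rank r places in the chunk destined for rank i
inputs = [[i + r * N for i in range(N)] for r in range(N)]

# An all-to-all delivers rank r's i-th chunk to rank i's r-th slot,
# so outputs[r][i] == inputs[i][r] -- a transpose of the chunk matrix.
outputs = [[inputs[i][r] for i in range(N)] for r in range(N)]

assert all(outputs[r][i] == r + i * N for r in range(N) for i in range(N))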
@@ -11,11 +11,9 @@ import torch.distributed._tensor.placement_types as placement_types
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.distributed_c10d import (
    _get_group_size_by_name,
    all_to_all,
    broadcast,
    get_global_rank,
    get_rank,
    get_world_size,
    GroupMember,
    ProcessGroup,
    scatter,
@@ -150,48 +148,6 @@ def mesh_broadcast(
    return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)


# TODO: test uneven split on GLOO and NCCL
def mesh_all_to_all(
    output_tensor_list: List[torch.Tensor],
    input_tensor_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    dim_group = mesh.get_group(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)

    work = None
    # no direct dist.all_to_all support on 'gloo' so we manually do scatters
    if mesh.device_type == "cpu":
        logger.warning(
            "ProcessGroupGloo does not support all_to_all, falling back with scatters!"
        )
        # TODO: pull the handle of uneven case in #492
        dim_group_size = get_world_size(dim_group)
        for i in range(dim_group_size):
            # src need to be global rank
            src_for_dim = i
            if dim_group is not GroupMember.WORLD:
                src_for_dim = get_global_rank(dim_group, i)

            work = scatter(
                output_tensor_list[i],
                input_tensor_list if mesh.get_rank() == src_for_dim else [],
                group=dim_group,
                src=src_for_dim,
                async_op=async_op,
            )
    else:
        work = all_to_all(
            output_tensor_list,
            input_tensor_list,
            dim_group,
            async_op=async_op,
        )
    return work


def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor:
    if pad_size == 0:
        return tensor
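With the helper gone, a caller that still needs a list-based all-to-all over a single mesh dimension can resolve that dimension's process group and call torch.distributed.all_to_all on it directly, which is exactly what the deleted function's non-gloo branch did. A minimal sketch, assuming a backend with native all_to_all support (per the removed warning, ProcessGroupGloo lacks it); the name all_to_all_over_mesh_dim is illustrative, not part of PyTorch:

from typing import List, Optional

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def all_to_all_over_mesh_dim(
    output_tensor_list: List[torch.Tensor],
    input_tensor_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[dist.Work]:
    # Same call the deleted helper made on non-gloo backends: resolve the
    # ProcessGroup backing this mesh dimension and dispatch to it.
    dim_group = mesh.get_group(mesh_dim)
    return dist.all_to_all(
        output_tensor_list, input_tensor_list, group=dim_group, async_op=async_op
    )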