# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import os

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._collective_utils import (
    mesh_all_to_all,
    mesh_broadcast,
    mesh_scatter,
)
from torch.distributed._tensor.placement_types import _Partial, Shard
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh

from torch.distributed.distributed_c10d import (
    get_global_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    is_nccl_available,
    ProcessGroup,
)
from torch.testing._internal.common_distributed import run_with_both_funcol_impls
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    skip_unless_torch_gpu,
    with_comms,
)
from torch.testing._internal.distributed.fake_pg import FakeStore


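# Helpers used by the tests below: pick "cuda" only when there are enough GPUs
# and NCCL is available, and populate the env vars that DeviceMesh falls back
# on when no process group has been initialized yet.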
def _get_device_type(world_size):
    if (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= world_size
        and is_nccl_available()
    ):
        device_type = "cuda"
    else:
        device_type = "cpu"
    return device_type


def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0):
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = f"{world_size}"
    os.environ["RANK"] = f"{rank}"


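# Basic DeviceMesh construction and accessor tests: implicit process-group
# initialization, mesh-tensor validation, and the get_group()/get_local_rank()
# APIs on 1D and 2D meshes.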
@instantiate_parametrized_tests
class DeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 4

    @run_with_both_funcol_impls
    def test_init_process_group(self):
        device_type = _get_device_type(self.world_size)
        mesh_tensor = torch.arange(4).reshape(2, 2)
        self.assertFalse(is_initialized())
        _set_env_var(world_size=self.world_size, rank=self.rank)
        # Constructing a DeviceMesh should initialize the default process
        # group on our behalf when none exists yet.
        DeviceMesh(device_type, mesh_tensor)
        self.assertTrue(is_initialized())
        self.destroy_pg()

    @with_comms
    @skip_unless_torch_gpu
    def test_assert_invalid_mesh_tensor(self):
        # A mesh tensor that has been moved onto a GPU is invalid and should
        # be rejected.
        mesh = torch.arange(self.world_size).to(self.rank)
        with self.assertRaises(ValueError):
            DeviceMesh(self.device_type, mesh)

    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_group(self):
        # TODO: `test_get_group` still periodically times out on CPU;
        # remove `@skip_unless_torch_gpu` once the problem is fixed.
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        tp_mesh = mesh_2d["tp"]
        dp_mesh = mesh_2d["dp"]

        # With no argument, get_group() returns one group per mesh dim.
        self.assertEqual(len(mesh_2d.get_group()), 2)
        self.assertEqual(mesh_2d.get_group()[0], mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group()[1], mesh_2d.get_group("tp"))

        # Dims are addressable by index or by name.
        self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp"))
        self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp"))

        # Slicing out a submesh preserves the underlying group.
        self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group())
        self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group())

    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_local_rank_raises_exception(self):
        # TODO: `test_get_local_rank_raises_exception` still periodically times
        # out on CPU; remove `@skip_unless_torch_gpu` once the problem is fixed.
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
        ):
            local_rank = mesh_2d.get_local_rank()

    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_local_rank(self):
        # TODO: `test_get_local_rank` still periodically times out on CPU;
        # remove `@skip_unless_torch_gpu` once the problem is fixed.
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=("dp", "tp")
        )
        self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0))
        self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1))

        dp_mesh = mesh_2d["dp"]
        tp_mesh = mesh_2d["tp"]
        self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp"))
        self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp"))

    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_2d(self):
        mesh_tensor = torch.arange(4).reshape(2, 2)
        # construct a 2D device mesh
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()

        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < 2)
            dim_ranks = expected_ranks_by_dim[dim]

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            current_rank_expected_group_ranks = (
                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
            )
            self.assertEqual(global_ranks, current_rank_expected_group_ranks)

    @run_with_both_funcol_impls
    def test_fake_pg_device_mesh(self):
        # A fake process group lets us exercise DeviceMesh and collectives
        # without spawning real workers.
        fake_store = FakeStore()
        init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size)
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        mesh = DeviceMesh(device_type, torch.arange(self.world_size))

        local_tensor = torch.randn(2, 8)
        global_tensor = funcol.all_gather_tensor(
            local_tensor, gather_dim=0, group=(mesh, 0)
        )
        self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))


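# DeviceMesh behavior on meshes with more than two dims: per-dim process
# groups on a 2x2x2 mesh, and hashing of mesh objects.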
class DeviceMeshTestNDim(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_nd(self):
        # construct a 3D device mesh
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()

        for dim, dim_group in enumerate(dim_to_subgroups):
            self.assertTrue(dim < mesh_tensor.ndim)
            # Swapping `dim` to the innermost position lists the rank groups
            # along that mesh dim, one group of 2 per row.
            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)

            dim_group_size = get_world_size(dim_group)
            self.assertIsInstance(dim_group, ProcessGroup)
            self.assertEqual(dim_group_size, 2)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            for ranks in dim_ranks:
                if self.rank in ranks:
                    self.assertEqual(global_ranks, ranks.tolist())

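    # Note: two DeviceMesh objects built from identical mesh tensors are still
    # distinct meshes, so their hashes must differ.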
    @with_comms
    @run_with_both_funcol_impls
    def test_device_mesh_hash(self):
        mesh_tensor_2d = torch.arange(8).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor_2d)
        mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d)
        self.assertNotEqual(hash(mesh), hash(mesh2))
        mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2)
        mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d)
        self.assertNotEqual(hash(mesh), hash(mesh3))
        self.assertNotEqual(hash(mesh2), hash(mesh3))


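# init_device_mesh(): the functional constructor, with and without
# mesh_dim_names, plus its input validation.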
class InitDeviceMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_init_device_mesh(self):
        mesh_shape = (2, 4)
        ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape))

        # test init_device_mesh with mesh_dim_names
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )
        self.assertEqual(mesh_2d, ref_mesh)
        self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names)

        # test init_device_mesh without mesh_dim_names
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)
        self.assertEqual(mesh_2d, ref_mesh)

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_duplicate_mesh_dim_names(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Each mesh_dim_name must be unique.",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (2, 4),
                mesh_dim_names=["dp", "dp"],
            )

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_mesh_shape_mesh_dim_names_mismatch(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "mesh_shape and mesh_dim_names should have same length!",
        ):
            mesh = init_device_mesh(
                self.device_type,
                (8,),
                mesh_dim_names=["dp", "tp"],
            )


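# Slicing submeshes out of a parent mesh with mesh[...]: invalid names,
# 1D/2D/3D slicing, and multi-dim slicing for HSDP-style layouts.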
@instantiate_parametrized_tests
class TestDeviceMeshGetItem(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_raises_invalid_mesh_dim_names(self):
        error_msg = "Invalid mesh_dim_name"
        # Case 1: the DeviceMesh does not have a mesh_dim_names attribute
        with self.assertRaisesRegex(
            RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names."
        ):
            mesh = init_device_mesh(self.device_type, (2, 4))
            child_mesh = mesh["DP"]

        # Case 2: the given mesh_dim_name is not in the parent mesh's mesh_dim_names
        child_mesh_dim_names = "PP"
        with self.assertRaisesRegex(ValueError, error_msg):
            mesh_dim_names = ("DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_names]

        # Case 3: none of the given mesh_dim_names is in the parent mesh's
        # mesh_dim_names
        child_mesh_dim_names = ["PP", "CP"]
        with self.assertRaisesRegex(ValueError, error_msg):
            mesh_dim_names = ("DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_names]

        # Case 4: the given mesh_dim_names are in the wrong order relative to
        # the parent mesh's mesh_dim_names
        child_mesh_dim_names = ("TP", "DP")
        with self.assertRaisesRegex(ValueError, error_msg):
            mesh_dim_names = ("DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 4), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_names]

        # Case 5: the given mesh_dim_names are not a contiguous subset of the
        # parent mesh's mesh_dim_names
        child_mesh_dim_names = ("PP", "TP")
        with self.assertRaisesRegex(ValueError, error_msg):
            mesh_dim_names = ("PP", "DP", "TP")
            mesh = init_device_mesh(
                self.device_type, (2, 2, 2), mesh_dim_names=mesh_dim_names
            )
            child_mesh = mesh[child_mesh_dim_names]

    @with_comms
    @run_with_both_funcol_impls
    @skip_if_lt_x_gpu(8)
    def test_get_item_2d(self):
        # TODO: `test_get_item_2d` still periodically times out on CPU;
        # remove `@skip_if_lt_x_gpu` once the problem is fixed.
        mesh_shape = (2, 4)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        tp_mesh = mesh_2d["TP"]
        tp_group = [[0, 1, 2, 3], [4, 5, 6, 7]]
        tp_group_idx = self.rank // 4
        self.assertEqual(tp_mesh.mesh.tolist(), tp_group[tp_group_idx])

        dp_mesh = mesh_2d["DP"]
        dp_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
        dp_group_idx = self.rank % 4
        self.assertEqual(dp_mesh.mesh.tolist(), dp_group[dp_group_idx])

    @with_comms
    @run_with_both_funcol_impls
    def test_get_item_1d(self):
        mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",))
        # Make sure slicing a 1D mesh out of a 1D mesh works: we simply get the
        # mesh itself back, with no parent mesh recorded.
        dp_mesh = mesh["dp"]
        self.assertEqual(dp_mesh, mesh)

        with self.assertRaisesRegex(ValueError, "Invalid mesh_dim_name"):
            dp_mesh = mesh["dim0"]

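    # Slicing a 3D ("Replicate", "Shard", "TP") mesh: each name maps to the
    # rank groups along that dim, and slicing two adjacent names at once
    # yields a 2D submesh (e.g. for HSDP-style replicate+shard layouts).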
    @with_comms
    @skip_if_lt_x_gpu(8)
    def test_get_item_3d(self):
        # TODO: `test_get_item_3d` still periodically times out on CPU;
        # remove `@skip_if_lt_x_gpu` once the problem is fixed.
        mesh_shape = (2, 2, 2)
        mesh_dim_names = ("Replicate", "Shard", "TP")
        mesh_3d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]]
        tp_group_idx = self.rank // 2
        self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx])

        shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]]
        shard_group_idx = self.rank % 2 + self.rank // 4 * 2
        self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx])

        replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]]
        replicate_group_idx = self.rank % 4
        self.assertEqual(
            mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx]
        )

        # We support both UXes for nD slicing:
        # mesh_3d[["Replicate", "Shard"]] and mesh_3d["Replicate", "Shard"].
        hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]]
        hsdp_mesh_2 = mesh_3d["Replicate", "Shard"]
        hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]]
        hsdp_group_idx = self.rank % 2
        self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx])
        self.assertEqual(hsdp_mesh_1, hsdp_mesh_2)


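# _mesh_resources bookkeeping: slicing registers the parent mesh, so child
# meshes can be mapped back to the parent and to the dim they were sliced from.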
@instantiate_parametrized_tests
class TestMeshEnv(DTensorTestBase):
    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_parent_mesh(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d)

        # Creating standalone meshes must not disturb the parent mapping.
        mesh_0_2 = DeviceMesh(self.device_type, [0, 2])
        mesh_1_3 = DeviceMesh(self.device_type, [1, 3])

        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_0_2), None)
        self.assertEqual(_mesh_resources.get_parent_mesh(mesh_1_3), None)

    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_parent_mesh_dim_exist(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["DP"]), 0)
        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["TP"]), 1)

    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_parent_mesh_dim_not_exist(self):
        mesh_shape = (self.world_size,)
        mesh = init_device_mesh(self.device_type, mesh_shape)

        self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh), None)

    @with_comms
    @run_with_both_funcol_impls
    @skip_unless_torch_gpu
    def test_get_mesh_dim_by_name(self):
        mesh_shape = (2, self.world_size // 2)
        mesh_dim_names = ("DP", "TP")
        mesh_2d = init_device_mesh(
            self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
        )

        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0)
        self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1)


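# Collectives routed through a DeviceMesh: broadcast, scatter, all_gather,
# reduce_scatter, and all_to_all, on 1D and nD meshes, including uneven
# (padded) sharding.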
@instantiate_parametrized_tests
class DeviceMeshCollectiveTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    @run_with_both_funcol_impls
    def test_broadcast_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
        # mesh_broadcast sources from the first rank of the dim group, which
        # holds zeros here.
        mesh_broadcast(local_tensor, mesh, mesh_dim=0)
        self.assertEqual(local_tensor, torch.zeros(3, 3))

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_1d(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        scatter_tensor_shape = [3, 3, 3]
        for scatter_dim in range(len(scatter_tensor_shape)):
            shard_placement = Shard(scatter_dim)
            scatter_tensor_shape[scatter_dim] *= self.world_size
            # use the same random seed across ranks
            torch.manual_seed(0)
            global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type)
            splitted_list, _ = shard_placement._split_tensor(
                global_tensor, mesh.size(), with_padding=True, contiguous=True
            )
            recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
            # scattering on dim > 0 produces non-contiguous chunks; verify that works
            mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0)
            self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])

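    # The "uneven" tests below shard a tensor whose dim sizes are not divisible
    # by world_size: _split_tensor(..., with_padding=True) pads every chunk to
    # the largest chunk's size so the collective sees equal shapes, and the
    # receiver unpads with _unpad_tensor afterwards.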
    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.randn(
            device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)

            tensor_to_scatter = tensor_to_split.clone()
            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
            mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0)

            if pad_sizes[my_rank] != 0:
                scattered_tensor = shard_placement._unpad_tensor(
                    scattered_tensor, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size() if a tensor is ([]) after
                # unpadding, since the size could be ([0, 8]) after unpadding.
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])

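    # all_gather is the inverse direction: gather the padded local chunks,
    # re-split the result, unpad each chunk, and check that we reassembled the
    # original tensor.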
    @with_comms
    @run_with_both_funcol_impls
    def test_all_gather_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = torch.ones(
            device_mesh.size() + 3,
            device_mesh.size() + 1,
            device=self.device_type,
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_padded_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_split,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            local_tensor = tensor_padded_list[my_rank]
            big_tensor = funcol.all_gather_tensor(
                local_tensor, gather_dim=shard_dim, group=(device_mesh, 0)
            )
            big_tensor_chunks = list(
                torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim)
            )
            unpadded_list = [
                shard_placement._unpad_tensor(chunk, pad_sizes[i])
                if pad_sizes[i] > 0
                else chunk
                for i, chunk in enumerate(big_tensor_chunks)
            ]
            all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim)

            self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
            self.assertEqual(all_gathered_tensor, tensor_to_split)

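    # In the contiguous reduce_scatter test below, rank r contributes
    # tensor * (r + 1), so each reduced element is the original element times
    # sum(1..world_size) = world_size * (world_size + 1) / 2 ("sum_base").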
    @with_comms
    @run_with_both_funcol_impls
    def test_reduce_scatter_contiguous(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()

        # Init the tensor
        step = self.world_size * 2
        total_elem = step**2
        tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type)
        tensor = tensor * (my_rank + 1)

        # Get a non-contiguous tensor by slicing
        tensor_to_reduce = tensor[::2, :2]
        tensor_contiguous = tensor_to_reduce.clone().contiguous()

        # Redistributing from _Partial to Shard triggers reduce_scatter
        tensor_to_reduce = DTensor.from_local(
            tensor_to_reduce, device_mesh, [_Partial()]
        )
        tensor_contiguous = DTensor.from_local(
            tensor_contiguous, device_mesh, [_Partial()]
        )
        new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)])
        new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)])

        # Contiguous and non-contiguous inputs with the same values should
        # produce the same reduce_scatter result.
        new_tensor_local = new_tensor._local_tensor
        new_tensor_contiguous_local = new_tensor_contiguous._local_tensor
        self.assertEqual(new_tensor_local, new_tensor_contiguous_local)
        self.assertEqual(list(new_tensor_local.size()), [1, 2])

        # Check the numerical value of the reduction
        sum_base = (1 + self.world_size) * self.world_size / 2
        first_elem = my_rank * sum_base * step * 2
        expected_tensor = torch.tensor(
            [[first_elem, first_elem + sum_base]],
            dtype=new_tensor_local.dtype,
            device=self.device_type,
        )
        self.assertEqual(new_tensor_local, expected_tensor)

    @with_comms
    @run_with_both_funcol_impls
    def test_reduce_scatter_uneven(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        my_rank = device_mesh.get_rank()
        tensor_to_split = (
            torch.ones(
                device_mesh.size() + 3,
                device_mesh.size() + 1,
                device=self.device_type,
            )
            * self.rank
        )

        for shard_dim in range(tensor_to_split.ndim):
            shard_placement = Shard(shard_dim)
            tensor_to_scatter = tensor_to_split.clone()

            tensor_splitted_list = list(
                torch.chunk(tensor_to_split, self.world_size, dim=shard_dim)
            )
            for _ in range(self.world_size - len(tensor_splitted_list)):
                tensor_splitted_list.append(torch.tensor([], device=self.device_type))

            padded_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor_to_scatter,
                device_mesh.size(),
                with_padding=True,
                contiguous=True,
            )

            tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim)

            # Rank r contributes ones * r, so the reduced value is
            # 0 + 1 + ... + (world_size - 1).
            res_num = ((0 + self.world_size - 1) * self.world_size) / 2

            scattered_tensor = funcol.reduce_scatter_tensor(
                tensor_to_reduce,
                reduceOp="sum",
                scatter_dim=shard_dim,
                group=(device_mesh, 0),
            )

            # unpad scattered_tensor
            if pad_sizes[my_rank] > 0:
                scattered_tensor = shard_placement._unpad_tensor(
                    scattered_tensor, pad_sizes[my_rank]
                )

            if scattered_tensor.numel() == 0:
                # We need to check numel() instead of size() if a tensor is ([]) after
                # unpadding, since the size could be ([0, 8]) after unpadding.
                self.assertEqual(
                    scattered_tensor.numel(), tensor_splitted_list[my_rank].numel()
                )
            else:
                self.assertEqual(
                    scattered_tensor.size(), tensor_splitted_list[my_rank].size()
                )
                self.assertEqual(
                    scattered_tensor,
                    torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
                )

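    # For nD meshes, each mesh dim has its own group; mesh_broadcast along a
    # dim sources from that dim group's first global rank (global_ranks[0]).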
    @with_comms
    @run_with_both_funcol_impls
    def test_broadcast_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank

        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            cloned_local_tensor = local_tensor.clone()
            mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim)
            res_num = global_ranks[0]
            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)

    @with_comms
    @run_with_both_funcol_impls
    def test_scatter_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            scattered_tensors = [
                torch.ones(3, 3, device=self.device_type) * global_rank
                for global_rank in global_ranks
            ]
            received_tensor = torch.empty_like(
                scattered_tensors[mesh.get_coordinate()[dim]]
            )
            mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim)
            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)

    @with_comms
    @run_with_both_funcol_impls
    def test_all_to_all_1d(self):
        # all_to_all on a 1D mesh acts like a transpose of the rank/chunk grid:
        # rank i's j-th input chunk becomes rank j's i-th output chunk.
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        tensor_shape = [3, 3]
        input_tensor_list = [
            torch.ones(*tensor_shape, device=self.device_type)
            * (rank + self.rank * self.world_size)
            for rank in range(self.world_size)
        ]
        expected_tensor_list = [
            torch.ones(tensor_shape, device=self.device_type)
            * (self.rank + rank * self.world_size)  # i.e. transpose
            for rank in range(self.world_size)
        ]
        for scatter_dim in range(len(tensor_shape)):
            output_tensor_list = [
                torch.empty_like(input_tensor_list[idx])
                for idx in range(len(input_tensor_list))
            ]
            # scattering on dim > 0 produces non-contiguous tensors; verify that works
            mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0)
            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)

            self.assertEqual(output_tensor, expected_tensor)

    @with_comms
    @run_with_both_funcol_impls
    def test_all_to_all_nd(self):
        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        tensor_shape = [3, 3, 3]
        # check all dim groups
        dim_to_subgroups = mesh.get_group()
        for dim, dim_group in enumerate(dim_to_subgroups):
            my_coordinate = mesh.get_coordinate()[dim]
            dim_group_size = get_world_size(dim_group)
            global_ranks = [
                get_global_rank(dim_group, i) for i in range(dim_group_size)
            ]
            input_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (i + self.rank * dim_group_size)
                for i in range(dim_group_size)
            ]
            expected_tensor_list = [
                torch.ones(*tensor_shape, device=self.device_type)
                * (my_coordinate + global_rank * dim_group_size)  # i.e. transpose
                for global_rank in global_ranks
            ]
            for scatter_dim in range(len(tensor_shape)):
                output_tensor_list = [
                    torch.empty_like(input_tensor_list[idx])
                    for idx in range(len(input_tensor_list))
                ]
                # scattering on dim > 0 produces non-contiguous tensors; verify that works
                mesh_all_to_all(
                    output_tensor_list, input_tensor_list, mesh, mesh_dim=dim
                )
                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
                expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
                self.assertEqual(output_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()