mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable local tensor mode DTensor tests for the optimizers, op strategy, matrix ops, math ops, init ops, experimental ops, embedding ops, dynamic, convolution ops, main api. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166105 Approved by: https://github.com/ezyang
275 lines
10 KiB
Python
275 lines
10 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
|
|
# Owner(s): ["oncall: distributed"]
|
|
|
|
import torch
|
|
from torch.distributed._local_tensor import maybe_run_for_local_tensor
|
|
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard, zeros
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|
create_local_tensor_test_class,
|
|
DTensorTestBase,
|
|
with_comms,
|
|
)
|
|
|
|
|
|
class DTensorInitOpsTest(DTensorTestBase):
    """Checks that in-place ``torch.nn.init`` ops behave the same on a DTensor
    as on an ordinary tensor holding the same local data."""

    def _run_init_op(self, init_op, *args, **kwargs):
        """Apply ``init_op`` to a plain tensor and to a DTensor, then compare.

        Args:
            init_op: callable invoked as ``init_op(tensor, *args, **kwargs)``;
                its return value is used as the initialized tensor.
            *args, **kwargs: forwarded unchanged to ``init_op``.
        """
        mesh = self.build_device_mesh()
        local_input = torch.randn(8, 4, device=self.device_type)
        sharded = DTensor.from_local(local_input, mesh, [Shard(0)])
        # Independent copy that serves as the non-distributed reference.
        reference = local_input.clone()

        # Seed identically before each call so both start from the same RNG
        # state.
        torch.manual_seed(self.rank)
        reference = init_op(reference, *args, **kwargs)

        torch.manual_seed(self.rank)
        sharded = init_op(sharded, *args, **kwargs)

        self.assertEqual(reference, sharded.to_local())

    @with_comms
    def test_init_ops(self):
        """Run the comparison for ``torch.nn.init.constant_``."""
        # NOTE: random init tests are moved to test_random_ops.py
        self._run_init_op(torch.nn.init.constant_, 2.4)
|
|
|
|
|
|
class DTensorConstructorTest(DTensorTestBase):
    """Tests for the DTensor factory functions ``ones``, ``empty``, ``full``
    and ``zeros`` from ``torch.distributed.tensor``."""

    @property
    def world_size(self):
        # The expected local shard shapes in this class are written for 4 ranks.
        return 4

    def _assert_same_tensor_meta(self, expected, actual):
        """Assert that two tensors agree on shape, dtype and layout.

        Used for ``empty``-style factories, whose contents are uninitialized
        and therefore cannot be compared by value.
        """
        self.assertEqual(expected.shape, actual.shape)
        self.assertEqual(expected.dtype, actual.dtype)
        self.assertEqual(expected.layout, actual.layout)

    def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs):
        """Compare a DTensor factory against its plain ``torch`` counterpart.

        Args:
            init_op: plain factory, called as ``init_op(size, *args, **kwargs)``.
            dist_init_op: DTensor factory, called with the same arguments plus
                ``device_mesh`` and ``placements``.
            eq_op: callable invoked as ``eq_op(expected, actual)``. Its return
                value is ignored, so it must raise (e.g. ``self.assertEqual``)
                when the two tensors do not match.
            *args, **kwargs: forwarded to both factories.
        """
        # 1d mesh test
        device_mesh = self.build_device_mesh()
        placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]]

        # even sharding
        tensor_size = [4, 8, 12]
        for placements in placements_list:
            local_tensor_size = tensor_size.copy()
            if isinstance(placements[0], Shard):
                shard_dim = placements[0].dim
                # Every dim of tensor_size is divisible by world_size, so each
                # rank holds an equally sized slice along the sharded dim.
                local_tensor_size[shard_dim] //= self.world_size

            dist_tensor = dist_init_op(
                tensor_size,
                *args,
                **kwargs,
                device_mesh=device_mesh,
                placements=placements,
            )
            expected = init_op(local_tensor_size, *args, **kwargs)
            eq_op(expected, dist_tensor.to_local())

        # uneven sharding
        tensor_size = [5, 10, 15]
        for placements in placements_list:
            dist_tensor = dist_init_op(
                tensor_size,
                *args,
                **kwargs,
                device_mesh=device_mesh,
                placements=placements,
            )
            if isinstance(placements[0], Shard):
                shard_dim = placements[0].dim
                # torch.chunk may return fewer than world_size pieces; ranks
                # without a corresponding piece are not checked below.
                exp_tensor_list = list(
                    torch.chunk(
                        init_op(tensor_size, *args, **kwargs),
                        self.world_size,
                        dim=shard_dim,
                    )
                )

                @maybe_run_for_local_tensor
                def check_per_rank_chunk(rank, local_tensor):
                    if rank < len(exp_tensor_list):
                        eq_op(exp_tensor_list[rank], local_tensor)

                check_per_rank_chunk(self.rank, dist_tensor.to_local())
            else:
                exp_tensor = init_op(tensor_size, *args, **kwargs)
                eq_op(exp_tensor, dist_tensor.to_local())

        # empty shape
        local_tensor = dist_init_op(
            [], *args, **kwargs, device_mesh=device_mesh, placements=[Replicate()]
        ).to_local()
        expected_tensor = init_op([], *args, **kwargs)
        eq_op(expected_tensor, local_tensor)

    @with_comms
    def test_ones(self):
        """``torch.distributed.tensor.ones`` matches ``torch.ones``."""
        self._run_init_op(
            torch.ones,
            torch.distributed.tensor.ones,
            self.assertEqual,
            requires_grad=True,
        )

    @with_comms
    def test_empty(self):
        """``torch.distributed.tensor.empty`` matches ``torch.empty`` in
        shape, dtype and layout."""
        # Values of an empty tensor are undefined, so only metadata is
        # compared. The comparison must assert: _run_init_op discards the
        # return value of eq_op, so a predicate that merely returns a bool
        # would never fail the test.
        self._run_init_op(
            torch.empty,
            torch.distributed.tensor.empty,
            self._assert_same_tensor_meta,
            requires_grad=True,
        )

    @with_comms
    def test_full(self):
        """``torch.distributed.tensor.full`` matches ``torch.full``."""
        self._run_init_op(
            torch.full,
            torch.distributed.tensor.full,
            self.assertEqual,
            123.4,
            requires_grad=True,
        )

    @with_comms
    def test_zeros(self):
        """``torch.distributed.tensor.zeros`` matches ``torch.zeros``."""
        self._run_init_op(
            torch.zeros,
            torch.distributed.tensor.zeros,
            self.assertEqual,
            requires_grad=True,
        )

    @with_comms
    def test_zeros_full_mesh(self):
        """``zeros`` on meshes covering every rank: 1d and 2d, evenly and
        unevenly sharded."""
        # construct a 1d device mesh
        mesh = self.build_device_mesh()
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([8, 3]))

        local_tensor = torch.zeros(8, 3)
        self.assertEqual(dist_tensor.to_local(), local_tensor)

        self.assertEqual(dist_tensor.device.type, self.device_type)

        # 1d sharded unevenly
        size = [31, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        @maybe_run_for_local_tensor
        def check_per_rank_tensors(rank, local_tensor):
            # 31 rows over 4 ranks: ranks 0-2 hold 8 rows, rank 3 holds 7.
            if rank <= 2:
                self.assertEqual(local_tensor.size(), torch.Size([8, 3]))
                self.assertEqual(torch.zeros(8, 3), local_tensor)
            else:
                self.assertEqual(local_tensor.size(), torch.Size([7, 3]))
                self.assertEqual(torch.zeros(7, 3), local_tensor)

        check_per_rank_tensors(self.rank, local_tensor)

        # construct a device mesh with 2d: shard, replicate
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        placements = [Shard(0), Replicate()]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 4]))
        self.assertEqual(local_tensor, torch.zeros([16, 4]))

        # construct a device mesh with 2d: shard, shard
        placements = [Shard(0), Shard(1)]
        size = [32, 4]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        self.assertEqual(local_tensor.size(), torch.Size([16, 2]))
        self.assertEqual(local_tensor, torch.zeros([16, 2]))

        # 2d sharded unevenly
        placements = [Shard(0), Shard(1)]
        size = [31, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)

        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()
        # Rows split 16/15 along mesh dim 0, columns split 2/1 along mesh dim 1.
        if self.rank == 0:
            self.assertEqual(local_tensor, torch.zeros([16, 2]))
        elif self.rank == 1:
            self.assertEqual(local_tensor, torch.zeros([16, 1]))
        elif self.rank == 2:
            self.assertEqual(local_tensor, torch.zeros([15, 2]))
        elif self.rank == 3:
            self.assertEqual(local_tensor, torch.zeros([15, 1]))

    @with_comms
    def test_zeros_submesh(self):
        """``zeros`` on meshes that cover only some ranks; ranks outside the
        mesh are expected to hold an empty local tensor."""
        # default world_size is 4
        # construct a 1d device mesh, with no sub pg initialized
        sub_mesh_list = [0, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.zeros(0))

        # construct a 1d device mesh: unevenly, with subpg initialized
        sub_mesh_list = [0, 1, 3]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in sub_mesh_list:
            # 32 rows over 3 mesh ranks: 11, 11 and 10 rows respectively.
            if self.rank != 3:
                self.assertEqual(local_tensor.size(), torch.Size([11, 3]))
                self.assertEqual(local_tensor, torch.zeros([11, 3]))
            else:
                self.assertEqual(local_tensor.size(), torch.Size([10, 3]))
                self.assertEqual(local_tensor, torch.zeros([10, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))

        # construct a 2d device mesh, with no subpg initialized
        sub_mesh_list = [[0], [3]]
        mesh = DeviceMesh(self.device_type, sub_mesh_list)
        placements = [Shard(0), Shard(1)]
        size = [32, 3]
        dist_tensor = zeros(size, device_mesh=mesh, placements=placements)
        self.assertEqual(dist_tensor.size(), torch.Size(size))
        local_tensor = dist_tensor.to_local()

        if self.rank in [0, 3]:
            self.assertEqual(local_tensor.size(), torch.Size([16, 3]))
            self.assertEqual(local_tensor, torch.zeros([16, 3]))
        else:
            self.assertEqual(local_tensor.size(), torch.Size([0]))
            self.assertEqual(local_tensor, torch.tensor([]))
|
|
|
|
|
|
# Variant of DTensorConstructorTest produced by create_local_tensor_test_class.
DTensorConstructorTestWithLocalTensor = create_local_tensor_test_class(
    DTensorConstructorTest,
    # Non-contiguous sub-meshes are not supported
    skipped_tests=["test_zeros_submesh"],
)


if __name__ == "__main__":
    run_tests()
|