pytorch/test/distributed/_tensor/test_utils.py
Ke Sang 6c061e5145 [DTensor][Shampoo] add _tensor.zeros function (#95863)
Summary:
implement the zeros function inside the DTensor API
- the user specifies the zeros tensor shape, and the function will create the local zero tensor given the placement information

Test Plan:
{F889157756} - unit test for the utility function compute_local_tensor_size
- unit test for _tensor.zeros

Reviewed By: wanchaol

Differential Revision: D43630718

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95863
Approved by: https://github.com/wanchaol
2023-03-03 19:36:44 +00:00

75 lines
2.5 KiB
Python

# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.utils import compute_local_tensor_size
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
class UtilTest(DTensorTestBase):
    """Tests for compute_local_tensor_size on a 2-D (4 x 2) device mesh."""

    @property
    def world_size(self):
        # All tests below assume exactly 8 ranks arranged as a 4 x 2 mesh.
        return 8

    def _build_mesh(self):
        # Lay out all ranks [0..7] as a 4 x 2 device mesh.
        return DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(4, 2)
        )

    @with_comms
    def test_compute_local_tensor_size_2d(self):
        """Evenly divisible case: every shard dimension splits cleanly."""
        device_mesh = self._build_mesh()
        global_size = torch.Size([8, 6])

        # Fully replicated: the local size matches the global size.
        local_replicated = compute_local_tensor_size(
            global_size, device_mesh, [Replicate(), Replicate()]
        )
        self.assertEqual(local_replicated, torch.Size([8, 6]))

        # Replicated over mesh dim 0; tensor dim 0 sharded over mesh dim 1
        # (2 ranks), so rows split 8 -> 4.
        local_row_sharded = compute_local_tensor_size(
            global_size, device_mesh, [Replicate(), Shard(0)]
        )
        self.assertEqual(local_row_sharded, torch.Size([4, 6]))

        # Both tensor dims sharded: rows 8 / 4 = 2, columns 6 / 2 = 3.
        local_both_sharded = compute_local_tensor_size(
            global_size, device_mesh, [Shard(0), Shard(1)]
        )
        self.assertEqual(local_both_sharded, torch.Size([2, 3]))

    @with_comms
    def test_compute_local_tensor_size_2d_not_evenly(self):
        """Uneven case: earlier mesh coordinates receive the larger shards."""
        device_mesh = self._build_mesh()
        global_size = torch.Size([7, 7])
        coords = device_mesh.get_coordinate()

        # Tensor dim 0 sharded over mesh dim 1 (2 ranks): 7 splits as 4 + 3.
        local_row_sharded = compute_local_tensor_size(
            global_size, device_mesh, [Replicate(), Shard(0)]
        )
        expected_rows = 4 if coords[1] < 1 else 3
        self.assertEqual(local_row_sharded, torch.Size([expected_rows, 7]))

        # Both dims sharded: dim 0 over 4 ranks (7 -> 2,2,2,1),
        # dim 1 over 2 ranks (7 -> 4,3).
        local_both_sharded = compute_local_tensor_size(
            global_size, device_mesh, [Shard(0), Shard(1)]
        )
        self.assertEqual(local_both_sharded[0], 2 if coords[0] < 3 else 1)
        self.assertEqual(local_both_sharded[1], 4 if coords[1] < 1 else 3)
# Standard PyTorch test entry point: dispatches to the shared test runner,
# which handles multi-process spawning for the distributed test cases.
if __name__ == "__main__":
    run_tests()