Summary:
Implement a zeros function in the DTensor API.
- The user specifies the shape of the zeros tensor, and the function creates the local zero tensor according to the placement information (see the usage sketch below).
Test Plan:
{F889157756}
- unit test for the compute_local_tensor_size util function
- unit test for _tensor.zeros
Reviewed By: wanchaol
Differential Revision: D43630718
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95863
Approved by: https://github.com/wanchaol
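
For context, a minimal usage sketch of the new factory. This assumes the function is exposed as torch.distributed._tensor.zeros and accepts device_mesh and placements keywords, inferred from this PR and the tests below; treat the exact import path and signature as an assumption, not the settled API.

# Illustrative sketch only: the import path and keyword names are assumed
# from this PR's tests, not confirmed against the final API.
import torch
from torch.distributed._tensor import zeros  # assumed export path
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard

# Run under torch.distributed with 8 ranks; a 4 x 2 device mesh.
mesh = DeviceMesh("cuda", torch.arange(8).reshape(4, 2))

# Global shape [8, 6]: Shard(0) over the first mesh dim (size 4) splits the
# rows four ways, Replicate() copies each shard across the second mesh dim,
# so each rank allocates a local [2, 6] zero tensor.
dt = zeros(8, 6, device_mesh=mesh, placements=[Shard(0), Replicate()])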
75 lines
2.5 KiB
Python
# Owner(s): ["oncall: distributed"]

import torch

from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.distributed._tensor.utils import compute_local_tensor_size

from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class UtilTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    @with_comms
    def test_compute_local_tensor_size_2d(self):
        # mesh: 4 * 2
        mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        size = torch.Size([8, 6])

        # replicate, replicate: every rank holds the full tensor
        placements1 = [Replicate(), Replicate()]
        local_size1 = compute_local_tensor_size(size, mesh, placements1)
        self.assertEqual(local_size1, torch.Size([8, 6]))

        # replicate, shard: tensor dim 0 is split over the second mesh dim (size 2)
        placements2 = [Replicate(), Shard(0)]
        local_size2 = compute_local_tensor_size(size, mesh, placements2)
        self.assertEqual(local_size2, torch.Size([4, 6]))

        # shard, shard: tensor dim 0 split 4 ways, tensor dim 1 split 2 ways
        placements3 = [Shard(0), Shard(1)]
        local_size3 = compute_local_tensor_size(size, mesh, placements3)
        self.assertEqual(local_size3, torch.Size([2, 3]))

    @with_comms
    def test_compute_local_tensor_size_2d_not_evenly(self):
        # mesh: 4 * 2
        mesh_tensor = torch.arange(self.world_size).reshape(4, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)
        size = torch.Size([7, 7])
        rank_coordinates = mesh.get_coordinate()

        # replicate, shard: 7 rows over 2 ranks -> 4 on coordinate 0, 3 on coordinate 1
        placements2 = [Replicate(), Shard(0)]
        local_size2 = compute_local_tensor_size(size, mesh, placements2)
        if rank_coordinates[1] < 1:
            self.assertEqual(local_size2, torch.Size([4, 7]))
        else:
            self.assertEqual(local_size2, torch.Size([3, 7]))

        # shard, shard
        placements3 = [Shard(0), Shard(1)]
        local_size3 = compute_local_tensor_size(size, mesh, placements3)
        # first dim: 7 over 4 ranks -> 2, 2, 2, 1
        if rank_coordinates[0] < 3:
            self.assertEqual(local_size3[0], 2)
        else:
            self.assertEqual(local_size3[0], 1)
        # second dim: 7 over 2 ranks -> 4, 3
        if rank_coordinates[1] < 1:
            self.assertEqual(local_size3[1], 4)
        else:
            self.assertEqual(local_size3[1], 3)


if __name__ == "__main__":
    run_tests()
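
The uneven-sharding expectations above follow a simple rule: a dimension of length n split over s shards gives the first n mod s coordinates ceil(n / s) elements and the remaining coordinates floor(n / s). Below is a hypothetical standalone helper illustrating that arithmetic; it is not the real compute_local_tensor_size implementation, just a sketch of the rule the tests assert.

# Hypothetical helper mirroring the uneven-shard arithmetic asserted above;
# not the actual compute_local_tensor_size implementation.
def local_dim_size(n: int, num_shards: int, coordinate: int) -> int:
    # The first (n % num_shards) coordinates take the larger chunk.
    quotient, remainder = divmod(n, num_shards)
    return quotient + 1 if coordinate < remainder else quotient

# dim of length 7 over 4 shards -> 2, 2, 2, 1 (coordinates 0..3)
assert [local_dim_size(7, 4, c) for c in range(4)] == [2, 2, 2, 1]
# dim of length 7 over 2 shards -> 4, 3
assert [local_dim_size(7, 2, c) for c in range(2)] == [4, 3]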