pytorch/test/distributed/test_multi_threaded_pg.py

77 lines
2.6 KiB
Python

# Owner(s): ["oncall: distributed"]
import sys
import torch
import torch.distributed as dist
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
spawn_threads_and_init_comms,
MultiThreadedTestCase
)
from torch.testing._internal.common_utils import TestCase, run_tests
DEFAULT_WORLD_SIZE = 4
class TestCollectivesWithWrapper(TestCase):
@spawn_threads_and_init_comms(world_size=4)
def test_broadcast_object_list(self):
val = 99 if dist.get_rank() == 0 else None
object_list = [val] * dist.get_world_size()
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
class TestCollectivesWithBaseClass(MultiThreadedTestCase):
@property
def world_size(self):
return 4
def test_allgather(self):
input_tensor = torch.ones(3, 3) * dist.get_rank()
output_tensors = [torch.empty_like(input_tensor) for _ in range(self.world_size)]
dist.all_gather(output_tensors, input_tensor)
for rank, out_tensor in enumerate(output_tensors):
self.assertEqual(out_tensor, torch.ones(3, 3) * rank)
def test_broadcast(self):
input_tensor = torch.ones(3, 3) * dist.get_rank()
for rank in range(self.world_size):
cloned_input = input_tensor.clone()
dist.broadcast(cloned_input, src=rank)
self.assertEqual(cloned_input, torch.ones(3, 3) * rank)
def test_scatter(self):
if dist.get_rank() == 0:
scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
else:
scatter_list = None
output_tensor = torch.empty(3, 3)
dist.scatter(output_tensor, scatter_list)
self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank())
def test_reduce_scatter(self):
to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(self.world_size)]
output_tensor = torch.empty(3, 3)
dist.reduce_scatter(output_tensor, to_reduce_scatter)
expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size
self.assertEqual(output_tensor, expected_tensor)
def test_broadcast_object_list(self):
val = 99 if dist.get_rank() == 0 else None
object_list = [val] * dist.get_world_size()
print(f"{dist.get_rank()} -> {dist.get_world_size()}")
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
if __name__ == "__main__":
run_tests()