# Owner(s): ["oncall: distributed"]

import os
import sys
from functools import wraps, partial

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS
)

from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))

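# Note: with_comms wraps each test method below. It skips the test when NCCL is
# selected but fewer GPUs than world_size are visible, initializes the default
# process group before the test body runs, and tears it down afterwards.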
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()
    return wrapper

class TestObjectCollectives(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
            else torch.device("cpu")

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

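    # The object-based collectives below pickle arbitrary Python objects into
    # tensors behind the scenes; with the NCCL backend those tensors are placed
    # on the current CUDA device, which is why dist_init() calls
    # torch.cuda.set_device for NCCL.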
    @with_comms()
    def test_all_gather_object(self):
        output = [None] * dist.get_world_size()
        dist.all_gather_object(
            object_list=output,
            obj=self.rank)

        for i, v in enumerate(output):
            self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_gather_object(self):
        output = [None] * dist.get_world_size() if self.rank == 0 else None
        dist.gather_object(
            obj=self.rank,
            object_gather_list=output)

        if self.rank == 0:
            for i, v in enumerate(output):
                self.assertEqual(i, v, f"rank: {self.rank}")

    @with_comms()
    def test_broadcast_object_list(self):
        val = 99 if self.rank == 0 else None
        object_list = [val] * dist.get_world_size()
        # TODO test with broadcast_object_list's device argument
        dist.broadcast_object_list(object_list=object_list)

        self.assertEqual(99, object_list[0])

    @with_comms()
    def test_scatter_object_list(self):
        input_list = list(range(dist.get_world_size())) if self.rank == 0 else None
        output_list = [None]
        dist.scatter_object_list(
            scatter_object_output_list=output_list,
            scatter_object_input_list=input_list)

        self.assertEqual(self.rank, output_list[0])

    # Test Object Collectives With Sub Pg

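    # setup_sub_pg carves the world into 2-rank subgroups (each even rank paired
    # with the next odd rank). new_group is called with
    # use_local_synchronization=True, so only the member ranks of each subgroup
    # need to participate in its creation rather than the whole world.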
    def setup_sub_pg(self):
        rank = dist.get_rank()
        base_rank = rank - (rank % 2)
        ranks = [base_rank, base_rank + 1]
        my_pg = dist.new_group(ranks, use_local_synchronization=True)
        return rank, ranks, my_pg

    @with_comms()
    def test_subpg_scatter_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        dist.scatter_object_list(out_list, ranks, src=ranks[0], group=my_pg)
        self.assertEqual(rank, out_list[0])

    @with_comms()
    def test_subpg_all_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks)
        dist.all_gather_object(out_list, rank, group=my_pg)
        self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_gather_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None] * len(ranks) if rank == ranks[0] else None
        dist.gather_object(rank, out_list, dst=ranks[0], group=my_pg)
        if rank == ranks[0]:
            self.assertEqual(ranks, out_list)

    @with_comms()
    def test_subpg_broadcast_object(self):
        rank, ranks, my_pg = self.setup_sub_pg()
        out_list = [None]
        if rank == ranks[0]:
            out_list[0] = rank
        dist.broadcast_object_list(out_list, src=ranks[0], group=my_pg)
        self.assertEqual(ranks[0], out_list[0])

if __name__ == "__main__":
    run_tests()