import sys
import tempfile
import unittest

import torch
import torch.distributed as c10d
import torch.multiprocessing as mp

from common_cuda import TEST_MULTIGPU
from common_utils import TestCase, load_tests, run_tests
from common_utils import NO_MULTIPROCESSING_SPAWN

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

if not c10d.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)

if NO_MULTIPROCESSING_SPAWN:
    print('spawn not available, skipping tests')
    sys.exit(0)


NO_NCCL = not hasattr(c10d, "ProcessGroupNCCL")


class ProcessGroupShareTensorTest(TestCase):

    world_size = 2

    @classmethod
    def opts(cls, threads=2):
        opts = c10d.ProcessGroupGloo.Options()
        opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
        opts.timeout = 5.0
        opts.threads = threads
        return opts

    @classmethod
    def _init_pg_gloo(cls, rank, filename, world_size):
        store = c10d.FileStore(filename, world_size)
        return c10d.ProcessGroupGloo(
            store, rank, world_size, ProcessGroupShareTensorTest.opts())

    @classmethod
    def _init_pg_nccl(cls, rank, filename, world_size):
        store = c10d.FileStore(filename, world_size)
        return c10d.ProcessGroupNCCL(store, rank, world_size)

    def _test_multiprocess(self, f, shared_tensors, init_pg, n_output):
        ws = self.world_size
        # The file store will delete the test file on destruction.
        file = tempfile.NamedTemporaryFile(delete=False)
        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(2)
        p2c = ctx.Queue(2)
        ps = []
        for i in range(ws):
            p = ctx.Process(
                target=f,
                args=(i, file.name, shared_tensors, ws, init_pg, c2p, p2c))
            p.start()
            ps.append(p)

        for _ in range(ws * n_output):
            pid, expected, result = c2p.get()
            self.assertEqual(
                expected,
                result,
                (
                    "Expected rank {} to receive tensor {}, but got {}."
                ).format(pid, expected, result)
            )

        for _ in range(ws):
            p2c.put(0)

        for p in ps:
            p.join(2)
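    # Each per-collective helper below follows the same handshake with
    # _test_multiprocess: the spawned process builds its process group, runs
    # one collective on its entry of ``shared_tensors``, reports
    # (rank, expected, actual) tuples through the ``c2p`` queue, and then
    # blocks on ``p2c`` so the process (and the CUDA tensors it shares) stays
    # alive until the parent has finished asserting.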
    # Why classmethod? multiprocessing cannot pickle a TestCase subclass when
    # in spawn mode. See https://bugs.python.org/issue33884.
    @classmethod
    def _test_broadcast_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.broadcast(xs).wait()
        c2p.put((rank, torch.zeros(2, 2), xs[0].to("cpu")))
        p2c.get()

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_broadcast_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_broadcast_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_gloo,
            1)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_broadcast_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_broadcast_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            1)

    @classmethod
    def _test_allreduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        pg.allreduce(xs, op=c10d.ReduceOp.SUM).wait()
        c2p.put((rank, torch.ones(2, 2) * 2, xs[0].to("cpu")))
        p2c.get()

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_allreduce_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allreduce_process,
            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_gloo,
            1)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_allreduce_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allreduce_process,
            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            1)

    @classmethod
    def _test_reduce_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        x = shared_tensors[rank]
        pg.reduce(x, root=0, op=c10d.ReduceOp.SUM).wait()
        if rank == 0:
            c2p.put((rank, torch.ones(2, 2) * 2, x.to("cpu")))
        else:
            c2p.put((rank, torch.ones(2, 2), x.to("cpu")))
        p2c.get()

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_reduce_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_reduce_process,
            [torch.ones(2, 2).to(i) for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            1)

    @classmethod
    def _test_allgather_process(
            cls, rank, filename, shared_tensors, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        xs = [shared_tensors[rank]]
        ys = [[torch.zeros_like(xs[0]) for _ in range(world_size)]]
        pg.allgather(ys, xs).wait()
        for i in range(world_size):
            c2p.put((rank, torch.ones(2, 2) * i, ys[0][i].to("cpu")))
        p2c.get()

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_allgather_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allgather_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_gloo,
            self.world_size)

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    @unittest.skipIf(NO_NCCL, "NCCL needed")
    def test_shared_allgather_nccl(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allgather_process,
            [torch.ones(2, 2).to(i) * i for i in range(self.world_size)],
            ProcessGroupShareTensorTest._init_pg_nccl,
            self.world_size)
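    # Unlike the tests above, the chunk variant below hands every rank the
    # same shared tensor; each rank contributes the torch.chunk() view it
    # owns. torch.chunk() returns views backed by the original storage, so
    # this also exercises collectives whose inputs alias a tensor shared
    # across processes.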
    @classmethod
    def _test_allgather_chunk_process(
            cls, rank, filename, shared_tensor, world_size, init_pg, c2p, p2c):
        pg = init_pg(rank, filename, world_size)
        chunks = torch.chunk(shared_tensor, world_size, dim=0)
        x = chunks[rank]
        ys = [torch.zeros_like(x) for _ in range(world_size)]
        pg.allgather(ys, x).wait()
        c2p.put((rank, chunks[0].to("cpu"), ys[0].to("cpu")))
        c2p.put((rank, chunks[1].to("cpu"), ys[1].to("cpu")))
        p2c.get()

    @unittest.skipIf(not TEST_MULTIGPU, "At least 2 CUDA GPUs needed")
    def test_shared_allgather_chunk_gloo(self):
        self._test_multiprocess(
            ProcessGroupShareTensorTest._test_allgather_chunk_process,
            torch.tensor(range(4)).reshape(2, 2),
            ProcessGroupShareTensorTest._init_pg_gloo,
            self.world_size)


if __name__ == '__main__':
    run_tests()